add machete as a gptq backend
LucasWilkinson committed Aug 12, 2024
1 parent 6dcbf08 commit 0e9b4c2
Showing 9 changed files with 202 additions and 135 deletions.
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/quantization/awq_marlin.py
@@ -9,7 +9,7 @@
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_tensor
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
@@ -240,23 +240,23 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight)
replace_parameter(layer, "qweight", marlin_qweight)

# Permute scales from AWQ format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size)
replace_tensor(layer, "scales", marlin_scales)
replace_parameter(layer, "scales", marlin_scales)

# Permute zero-points from AWQ format to marlin format.
marlin_zp = awq_to_marlin_zero_points(
layer.qzeros,
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qzeros", marlin_zp)
replace_parameter(layer, "qzeros", marlin_zp)

# Not-used
layer.g_idx = marlin_make_empty_g_idx(device)
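The awq_marlin.py hunks above only swap the `replace_tensor` helper for `replace_parameter`. As a rough illustration of what such a helper has to do (a minimal sketch built on standard torch.nn.Module mechanics, not the actual implementation in vllm/model_executor/layers/quantization/utils):

```python
# Hedged sketch of a replace_parameter-style helper: swap a module attribute
# for a repacked tensor while keeping it registered under the same name.
import torch


def replace_parameter_sketch(mod: torch.nn.Module, name: str,
                             new: torch.Tensor) -> None:
    # nn.Module.__delattr__ handles Parameters, buffers and plain attributes.
    delattr(mod, name)
    # Re-register the repacked tensor as a non-trainable Parameter so that
    # getattr() and state_dict() still find it under the original name.
    mod.register_parameter(
        name, torch.nn.Parameter(new.detach(), requires_grad=False))
```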
@@ -8,6 +8,7 @@
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.scalar_type import scalar_types
@@ -68,7 +69,6 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
)

kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
self.kernel = kernel_type(mp_linear_kernel_config)

# If group_size is -1, we are in channelwise case.
channelwise = (self.group_size == -1)
@@ -125,9 +125,29 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)

self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name=None,
w_gidx_param_name=None)

# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# convert `weight_packed` from:
# {input_dim = 1, output_dim = 0, packed_dim = 1}
# to:
# {input_dim = 0, output_dim = 1, packed_dim = 0}
# as expected by the kernel's `process_weights_after_loading`
replace_parameter(layer, "weight_packed", layer.weight_packed.t())

# convert `weight_scale` from:
# {input_dim = 1, output_dim = 0}
# to:
# {input_dim = 0, output_dim = 1}
# as expected by the kernel's `process_weights_after_loading`
replace_parameter(layer, "weight_scale", layer.weight_scale.t())

self.kernel.process_weights_after_loading(layer)


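The new process_weights_after_loading above only re-orients `weight_packed` and `weight_scale` so that the input (K) dimension comes first, which is the layout the chosen kernel's own process_weights_after_loading expects. A toy illustration of that dimension swap (shapes and names are invented for the example):

```python
import torch

out_features, packed_in_features, num_groups = 4, 8, 2

# compressed-tensors layout: {input_dim = 1, output_dim = 0, packed_dim = 1}
weight_packed = torch.zeros(out_features, packed_in_features, dtype=torch.int32)
# compressed-tensors layout: {input_dim = 1, output_dim = 0}
weight_scale = torch.ones(out_features, num_groups, dtype=torch.float16)

# .t() swaps the two dims, giving {input_dim = 0, output_dim = 1, packed_dim = 0}
# and {input_dim = 0, output_dim = 1}; it only changes the view, and the actual
# repacking into the kernel's native format happens later inside the kernel.
assert weight_packed.t().shape == (packed_in_features, out_features)
assert weight_scale.t().shape == (num_groups, out_features)
```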
86 changes: 27 additions & 59 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -5,11 +5,14 @@

from vllm import _custom_ops as ops
from vllm.logger import init_logger

from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)

from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_tensor
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
@@ -160,22 +163,28 @@ def create_weights(
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition

mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
act_reordering=self.quant_config.desc_act
)

kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size

verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size)

# Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size,
@@ -269,68 +278,27 @@ def create_weights(
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
is_row_parallel)

self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx")

# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device

# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)

# Handle sorting for activation reordering if needed.
if self.quant_config.desc_act:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "g_idx", g_idx)
else:
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

# No zero-point
layer.zp = marlin_make_empty_g_idx(device)

# Repack weights from autogptq format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight)

# Permute scales from autogptq format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=(layer.input_size if self.quant_config.desc_act else
layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size)
replace_tensor(layer, "scales", marlin_scales)
# `qweight` and `scales` are already in the format the kernel expects, so we
# can call the kernel's `process_weights_after_loading` right away.
self.kernel.process_weights_after_loading(layer)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_gptq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full,
bias=bias)
return self.kernel.apply_weights(layer, x, bias)
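With this change, GPTQMarlinLinearMethod builds an MPLinearLayerConfig, lets choose_mp_linear_kernel pick a kernel class (Machete or Marlin), and defers repacking and the forward GEMM to that kernel. A simplified sketch of how such a chooser could work, assuming it walks a preference-ordered candidate list and relies on each class's can_implement; the real selection logic lives under vllm/model_executor/layers/quantization/kernels and may differ:

```python
# Illustrative only: preference-ordered kernel selection. The candidate list,
# function name and error text are assumptions of this sketch.
from typing import List, Type


def choose_kernel_sketch(config, candidates: List[Type]) -> Type:
    failure_reasons = []
    for kernel_cls in candidates:  # e.g. Machete first, Marlin as fallback
        can_impl, reason = kernel_cls.can_implement(config)
        if can_impl:
            return kernel_cls
        failure_reasons.append(f"  {kernel_cls.__name__}: {reason}")
    raise ValueError("No mixed-precision linear kernel supports this layer:\n" +
                     "\n".join(failure_reasons))
```

Routing both GPTQ and compressed-tensors checkpoints through the same MPLinearKernel interface is what lets them share the Machete and Marlin backends.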
43 changes: 37 additions & 6 deletions vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
@@ -1,9 +1,10 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict, Callable

import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType


@@ -31,18 +32,48 @@ def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
raise NotImplementedError

def __init__(self, c: MPLinearLayerConfig) -> None:
def __init__(self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None) -> None:
ok, failure_reason = self.can_implement(c)
assert ok, failure_reason
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
self.w_zp_name = w_zp_param_name
self.w_gidx_name = w_gidx_param_name

# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
# `getattr(layer, w_q_name)` is:
# {input_dim = 0, output_dim = 1, packed_dim = 0}
# `getattr(layer, w_s_name)` is:
# {input_dim = 0, output_dim = 1}
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
raise NotImplementedError

@abstractmethod
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
raise NotImplementedError

def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
fn: Callable) -> None:
if name is not None and getattr(layer, name, None) is not None:
replace_parameter(layer, name, fn(getattr(layer, name)))

def _get_weight_params(
self, layer: torch.nn.Module
) -> Tuple[torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], Optional[torch.Tensor]]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
getattr(layer, self.w_zp_name or "", None),
getattr(layer, self.w_gidx_name or "", None),
)
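The base class now receives the layer's parameter names at construction time and exposes _transform_param and _get_weight_params, so concrete kernels can repack checkpoints without hard-coding whether the weights are called qweight/scales (GPTQ) or weight_packed/weight_scale (compressed-tensors). A hedged sketch of a subclass using these helpers; it assumes the abstract base in MPLinearKernel.py is named MPLinearKernel, and ToyKernel with its no-op "repack" is invented for illustration:

```python
from typing import Optional, Tuple

import torch

# Assumed import path, based on the file location in this commit.
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
    MPLinearKernel, MPLinearLayerConfig)


class ToyKernel(MPLinearKernel):

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Repack by configured name; optional params (zero-points, g_idx) that
        # the layer never registered are skipped by _transform_param.
        self._transform_param(layer, self.w_q_name, lambda w: w.contiguous())
        self._transform_param(layer, self.w_s_name, lambda s: s.contiguous())

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # A real kernel would fetch the tensors with
        # self._get_weight_params(layer) and run its quantized GEMM here.
        raise NotImplementedError
```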
@@ -1,5 +1,5 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_tensor
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.machete_utils import (
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
query_machete_supported_quant_types)
@@ -41,32 +41,29 @@ def can_implement(cls,

# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
# Repack weights and scales for Machete

replace_tensor(
layer, "weight_packed",
ops.machete_prepack_B(layer.weight_packed.t(),
self.config.weight_type))
replace_tensor(layer, "weight_scale",
layer.weight_scale.clone().contiguous().t())

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
self._transform_param(
layer, self.w_q_name, lambda x: ops.machete_prepack_B(
x.t().contiguous().t(), self.config.weight_type))
self._transform_param(layer, self.w_s_name, lambda x: x.contiguous())

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)

# print("running machete")
# print(layer.weight_packed.shape)
# print(layer.weight_scale.dtype, x.dtype, layer.group_size)
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

output = ops.machete_gemm(a=x_2d,
b_q=layer.weight_packed,
b_q=w_q,
b_type=c.weight_type,
b_zeros=None,
b_scales=layer.weight_scale,
b_scales=w_s,
b_group_size=c.group_size)

if bias is not None:
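The Machete apply_weights path shown above flattens all leading activation dimensions into a single 2D GEMM and restores them afterwards, taking the output width from the partitioned weight shape. A toy version of that reshape pattern, with a plain matmul standing in for ops.machete_gemm:

```python
import torch

x = torch.randn(2, 3, 16)  # activations with arbitrary leading dims, last dim K
w = torch.randn(16, 32)    # dense (K, N) stand-in for the packed Machete operand

x_2d = x.reshape(-1, x.shape[-1])           # collapse leading dims: (6, 16)
out_shape = x.shape[:-1] + (w.shape[-1], )  # (2, 3, 32)

out = (x_2d @ w).reshape(out_shape)         # ops.machete_gemm in the real kernel
assert out.shape == (2, 3, 32)
```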
