
Commit

refactor permute layout
LucasWilkinson committed Aug 15, 2024
1 parent 6fdc7bd · commit 467848d
Showing 4 changed files with 94 additions and 85 deletions.
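
In brief, this commit replaces the per-class permute_layout methods on the vLLM parameter classes (which built and returned a new parameter object) with a single in-place helper, permute_param_layout_. A minimal before/after sketch of the calling convention, using names from the diffs below; w_q here is just an illustrative packed-weight parameter, not a variable from the commit:

    # Before: permute_layout constructed and returned a fresh parameter object.
    # w_q is a hypothetical PackedvLLMParameter holding a packed quantized weight.
    w_q = w_q.permute_layout(input_dim=0, output_dim=1, packed_dim=0)

    # After: permute_param_layout_ permutes w_q.data in place (the trailing
    # underscore signals mutation) and returns the same parameter instance.
    permute_param_layout_(w_q, input_dim=0, output_dim=1, packed_dim=0)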
3 changes: 2 additions & 1 deletion csrc/quantization/machete/machete_mm_kernel.cuh
@@ -152,7 +152,8 @@ struct MacheteKernelTemplate {

int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);

-    int const group_size = maybe_group_size.value_or(K);
+    int group_size = maybe_group_size.value_or(K);
+    group_size = (group_size == -1) ? K : group_size;
int const scale_k = (K + group_size - 1) / group_size;

TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
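
The kernel change above normalizes a group size of -1 to mean "no grouping": a single group spanning the whole reduction dimension K, i.e. channelwise scales. A small Python sketch (not the committed CUDA code) of that normalization and of scale_k, the number of scale rows, computed by ceil-division:

    def normalize_group_size(maybe_group_size, K):
        # None and -1 both collapse to one group covering all of K.
        group_size = K if maybe_group_size is None else maybe_group_size
        group_size = K if group_size == -1 else group_size
        scale_k = (K + group_size - 1) // group_size  # ceil(K / group_size)
        return group_size, scale_k

    assert normalize_group_size(-1, 4096) == (4096, 1)   # channelwise: one scale row
    assert normalize_group_size(128, 4096) == (128, 32)  # 32 groups of 128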
@@ -2,8 +2,8 @@
from vllm.model_executor.layers.quantization.utils.machete_utils import (
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
query_machete_supported_quant_types)
-from vllm.model_executor.parameter import (ModelWeightParameter,
-                                           PackedvLLMParameter)
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           permute_param_layout_)

from .MPLinearKernel import *

@@ -46,19 +46,17 @@ def can_implement(cls,
def process_weights_after_loading(self, layer: torch.nn.Module):

def transform_w_q(x):
-            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
-            #  everything is migrated to using weight_loader_v2
-            if isinstance(x, PackedvLLMParameter):
-                x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0)
-            return ops.machete_prepack_B(x.t().contiguous().t(),
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
                                           self.config.weight_type)
+            return x

def transform_w_s(x):
-            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
-            #  everything is migrated to using weight_loader_v2
-            if isinstance(x, ModelWeightParameter):
-                x = x.permute_layout(input_dim=0, output_dim=1)
-            return x.contiguous()
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = x.data.contiguous()
+            return x

# Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q)
@@ -74,6 +72,9 @@ def apply_weights(self,
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

+        print(w_s)
+        print(c.group_size)
+
output = ops.machete_gemm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
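
The transform hooks above now mutate the parameter's .data in place and return the same object, rather than returning a brand-new tensor or parameter as the old permute_layout path did. A rough sketch of how a driver like self._transform_param (presumably provided by the MPLinearKernel base class, which this diff does not show) could apply such a hook; this is an illustration under that assumption, not the actual vLLM implementation:

    import torch

    def apply_transform(layer: torch.nn.Module, name: str, fn) -> None:
        # Fetch the registered parameter, let the hook permute/repack it
        # (typically by rewriting param.data in place), then re-register the
        # result so the layer points at the transformed weight.
        old_param = getattr(layer, name)
        new_param = fn(old_param)
        # requires_grad=False: these are frozen, already-quantized weights.
        setattr(layer, name,
                torch.nn.Parameter(new_param.data, requires_grad=False))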
@@ -4,8 +4,8 @@
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
query_marlin_supported_quant_types)
-from vllm.model_executor.parameter import (ModelWeightParameter,
-                                           PackedvLLMParameter)
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           permute_param_layout_)

from .MPLinearKernel import *

@@ -83,25 +83,23 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

def transform_w_q(x):
-            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
-            #  everything is migrated to using weight_loader_v2
-            if isinstance(x, PackedvLLMParameter):
-                x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0)
-            return ops.gptq_marlin_repack(x.contiguous(),
-                                          perm=layer.g_idx_sort_indices,
-                                          size_k=c.partition_weight_shape[0],
-                                          size_n=c.partition_weight_shape[1],
-                                          num_bits=c.weight_type.size_bits)
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x.data = ops.gptq_marlin_repack(x.data.contiguous(),
+                                            perm=layer.g_idx_sort_indices,
+                                            size_k=c.partition_weight_shape[0],
+                                            size_n=c.partition_weight_shape[1],
+                                            num_bits=c.weight_type.size_bits)
+            return x

def transform_w_s(x):
-            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
-            #  everything is migrated to using weight_loader_v2
-            if isinstance(x, ModelWeightParameter):
-                x = x.permute_layout(input_dim=0, output_dim=1)
-            return marlin_permute_scales(x.contiguous(),
-                                         size_k=c.partition_weight_shape[0],
-                                         size_n=c.partition_weight_shape[1],
-                                         group_size=c.group_size)
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = marlin_permute_scales(x.data.contiguous(),
+                                           size_k=c.partition_weight_shape[0],
+                                           size_n=c.partition_weight_shape[1],
+                                           group_size=c.group_size)
+            return x

self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
115 changes: 62 additions & 53 deletions vllm/model_executor/parameter.py
@@ -145,25 +145,6 @@ def __init__(self, input_dim: int, **kwargs):
self._input_dim = input_dim
super().__init__(**kwargs)

-    def permute_layout(self, input_dim: int, output_dim: int, **kwargs) \
-            -> 'ModelWeightParameter':
-
-        # create permutation from the current layout to the layout with
-        # self.input_dim at input_dim and self.output_dim at output_dim
-        # preservier other dimensions
-        perm = [
-            i for i in range(self.data.dim())
-            if i not in [self.input_dim, self.output_dim]
-        ]
-        perm.insert(input_dim, self.input_dim)
-        perm.insert(output_dim, self.output_dim)
-
-        return ModelWeightParameter(data=self.data.permute(*perm).contiguous(),
-                                    weight_loader=self.weight_loader,
-                                    input_dim=input_dim,
-                                    output_dim=output_dim,
-                                    **kwargs)
-
@property
def input_dim(self):
return self._input_dim
@@ -278,23 +259,6 @@ def __init__(self,
self._marlin_tile_size = marlin_tile_size
super().__init__(**kwargs)

-    def permute_layout(self, input_dim: int, output_dim: int,
-                       packed_dim: int = 0,
-                       **kwargs)\
-            -> 'PackedColumnParameter':
-
-        assert packed_dim == packed_dim
-
-        return PackedColumnParameter(
-            data=ModelWeightParameter\
-                .permute_layout(self, input_dim, output_dim).data,
-            weight_loader=self.weight_loader,
-            input_dim=input_dim,
-            output_dim=output_dim,
-            packed_dim=self.packed_dim,
-            packed_factor=self.packed_factor,
-            marlin_tile_size=self.marlin_tile_size)
-
@property
def packed_dim(self):
return self._packed_dim
Expand Down Expand Up @@ -336,23 +300,6 @@ def __init__(self,
self._marlin_tile_size = marlin_tile_size
super().__init__(**kwargs)

-    def permute_layout(self, input_dim: int, output_dim: int,
-                       packed_dim: int = 0,
-                       **kwargs)\
-            -> 'PackedvLLMParameter':
-
-        assert packed_dim == packed_dim
-
-        return PackedvLLMParameter(
-            data=ModelWeightParameter\
-                .permute_layout(self, input_dim, output_dim).data,
-            weight_loader=self.weight_loader,
-            input_dim=input_dim,
-            output_dim=output_dim,
-            packed_dim=self.packed_dim,
-            packed_factor=self.packed_factor,
-            marlin_tile_size=self.marlin_tile_size)
-
@property
def packed_dim(self):
return self._packed_dim
@@ -373,6 +320,68 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
marlin_tile_size=self.marlin_tile_size)


+def permute_param_layout_(
+    param: BasevLLMParameter,
+    input_dim: int,
+    output_dim: int,
+    **kwargs
+) -> BasevLLMParameter:
+    """
+    Permute a parameter's layout to the specified input and output dimensions,
+    useful for forcing the parameter into a known layout, for example, if I need
+    a packed (quantized) weight matrix to be in the layout
+        {input_dim = 0, output_dim = 1, packed_dim = 0}
+    then I can call:
+        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+    to ensure x is in the correct layout (permuting it to the correct layout if
+    required, asserting if it cannot get it to the correct layout)
+    """
+
+    curr_input_dim = getattr(param, "input_dim", None)
+    curr_output_dim = getattr(param, "output_dim", None)
+
+    if curr_input_dim is None or curr_output_dim is None:
+        assert param.data.dim() == 2,\
+            "permute_param_layout_ only supports 2D parameters where either "\
+            "input_dim or output_dim is not set"
+
+    # if one of the dimensions is not set, set it to the opposite of the other
+    #  we can only do this since we asserted the parameter is 2D above
+    if curr_input_dim is None:
+        assert curr_output_dim is not None,\
+            "either input or output dim must be set"
+        curr_input_dim = (curr_output_dim + 1) % 2
+    if curr_output_dim is None:
+        assert curr_input_dim is not None,\
+            "either input or output dim must be set"
+        curr_output_dim = (curr_input_dim + 1) % 2
+
+    # create permutation from the current layout to the layout with
+    # self.input_dim at input_dim and self.output_dim at output_dim preserving
+    # other dimensions
+    perm = [
+        i for i in range(param.data.dim())
+        if i not in [curr_input_dim, curr_output_dim]
+    ]
+    perm.insert(input_dim, curr_input_dim)
+    perm.insert(output_dim, curr_output_dim)
+
+    if "packed_dim" in kwargs:
+        assert hasattr(param, "packed_dim") and\
+            param.packed_dim == perm[kwargs["packed_dim"]],\
+            "permute_param_layout_ currently doesn't support repacking"
+
+    param.data = param.data.permute(*perm)
+    if hasattr(param, "_input_dim"):
+        param._input_dim = input_dim
+    if hasattr(param, "_output_dim"):
+        param._output_dim = output_dim
+    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
+        param._packed_dim = kwargs["packed_dim"]
+
+    return param
+
+
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
marlin_tile_size):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
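
To make the new helper's docstring concrete, a small usage sketch: it assumes a 2D ModelWeightParameter whose scales were loaded output-major (output_dim=0, input_dim=1); the shapes and the dummy weight_loader are illustrative only.

    import torch
    from vllm.model_executor.parameter import (ModelWeightParameter,
                                               permute_param_layout_)

    # Scales stored as (N, K // group_size): output_dim=0, input_dim=1.
    scales = ModelWeightParameter(data=torch.empty(4096, 32),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=lambda *args: None)

    # Force the {input_dim=0, output_dim=1} layout expected by the repacking
    # code; scales.data is permuted in place and the tracked dims are updated.
    permute_param_layout_(scales, input_dim=0, output_dim=1)
    assert scales.data.shape == (32, 4096)  # now (K // group_size, N)
    assert scales.input_dim == 0 and scales.output_dim == 1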
