Skip to content

Commit

Permalink
update to use ModelWeightParameter
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Aug 7, 2024
1 parent 2da121b commit 9696a55
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,6 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# convert `weight_packed` from:
# {input_dim = 1, output_dim = 0, packed_dim = 1}
# to:
# {input_dim = 0, output_dim = 1, packed_dim = 0}
# expected the kernels `process_weights_after_loading`
replace_parameter(layer, "weight_packed", layer.weight_packed.t())

# convert `weight_scale` from:
# {input_dim = 1, output_dim = 0}
# to:
# {input_dim = 0, output_dim = 1}
# expected the kernels `process_weights_after_loading`
replace_parameter(layer, "weight_scale", layer.weight_scale.t())

self.kernel.process_weights_after_loading(layer)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self,
self.w_zp_name = w_zp_param_name
self.w_gidx_name = w_gidx_param_name

# note assumes that
# note assumes that (if the they are not ModelWeightParameters)
# `getattr(layer, w_q_name)` is:
# {input_dim = 0, output_dim = 1, packed_dim = 0}
# `getattr(layer, w_s_name)` is:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.machete_utils import (
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
query_machete_supported_quant_types)

from vllm.model_executor.parameter import (ModelWeightParameter,
PackedvLLMParameter)
from .MPLinearKernel import *


Expand Down Expand Up @@ -43,11 +43,23 @@ def can_implement(cls,
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
def transform_w_q(x):
# TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
# everything is migrated to using weight_loader_v2
if isinstance(x, PackedvLLMParameter):
x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0)
return ops.machete_prepack_B(x.t().contiguous().t(),
self.config.weight_type)
def transform_w_s(x):
# TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
# everything is migrated to using weight_loader_v2
if isinstance(x, ModelWeightParameter):
x = x.permute_layout(input_dim=0, output_dim=1)
return x.contiguous()

# Repack weights and scales for Machete
self._transform_param(
layer, self.w_q_name, lambda x: ops.machete_prepack_B(
x.t().contiguous().t(), self.config.weight_type))
self._transform_param(layer, self.w_s_name, lambda x: x.contiguous())
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)

def apply_weights(self,
layer: torch.nn.Module,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
marlin_is_k_full, query_marlin_supported_quant_types)

from .MPLinearKernel import *

from vllm.model_executor.parameter import (ModelWeightParameter,
PackedvLLMParameter)

class MarlinLinearKernel(MPLinearKernel):

Expand Down Expand Up @@ -80,20 +81,31 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

self._transform_param(layer, self.w_q_name, lambda x: \
ops.gptq_marlin_repack(
def transform_w_q(x):
# TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
# everything is migrated to using weight_loader_v2
if isinstance(x, PackedvLLMParameter):
x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0)
return ops.gptq_marlin_repack(
x.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits))

self._transform_param(layer, self.w_s_name, lambda x: \
marlin_permute_scales(
num_bits=c.weight_type.size_bits)

def transform_w_s(x):
# TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
# everything is migrated to using weight_loader_v2
if isinstance(x, ModelWeightParameter):
x = x.permute_layout(input_dim=0, output_dim=1)
return marlin_permute_scales(
x.contiguous(),
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size))
group_size=c.group_size)

self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)

def apply_weights(self,
layer: torch.nn.Module,
Expand Down
36 changes: 36 additions & 0 deletions vllm/model_executor/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,25 @@ class ModelWeightParameter(_ColumnvLLMParameter):
def __init__(self, input_dim: int, **kwargs):
self._input_dim = input_dim
super().__init__(**kwargs)

def permute_layout(self, input_dim: int, output_dim: int, **kwargs) \
-> 'ModelWeightParameter':

# create permutation from the current layout to the layout with
# self.input_dim at input_dim and self.output_dim at output_dim
# preservier other dimensions
perm = [i for i in range(self.data.dim())
if i not in [self.input_dim, self.output_dim]
]
perm.insert(input_dim, self.input_dim)
perm.insert(output_dim, self.output_dim)

return ModelWeightParameter(
data=self.data.permute(*perm).contiguous(),
weight_loader=self.weight_loader,
input_dim=input_dim,
output_dim=output_dim,
**kwargs)

@property
def input_dim(self):
Expand Down Expand Up @@ -253,6 +272,23 @@ def __init__(self,
self._marlin_tile = marlin_tile_size
super().__init__(**kwargs)

def permute_layout(self, input_dim: int, output_dim: int,
packed_dim: int = 0,
**kwargs)\
-> 'ModelWeightParameter':

assert packed_dim == packed_dim

return PackedvLLMParameter(
data=ModelWeightParameter\
.permute_layout(self, input_dim, output_dim).data,
weight_loader=self.weight_loader,
input_dim=input_dim,
output_dim=output_dim,
packed_dim=self.packed_dim,
packed_factor=self.packed_factor,
marlin_tile_size=self.marlin_tile)

@property
def packed_dim(self):
return self._packed_dim
Expand Down

0 comments on commit 9696a55

Please sign in to comment.