From 0e9b4c281f3b08ffcfa7dfed87a96a60d31c308a Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 6 Aug 2024 01:02:07 +0000
Subject: [PATCH] add machete as a gptq backend

---
 .../layers/quantization/awq_marlin.py         |   8 +-
 .../schemes/compressed_tensors_wNa16.py       |  22 +++-
 .../layers/quantization/gptq_marlin.py        |  86 +++++----------
 .../quantization/kernels/MPLinearKernel.py    |  43 ++++++--
 .../kernels/MacheteLinearKernel.py            |  33 +++---
 .../kernels/MarlinLinearKernel.py             | 102 +++++++++++-------
 .../layers/quantization/utils/__init__.py     |   4 +-
 .../layers/quantization/utils/layer_utils.py  |  34 ++++--
 .../layers/quantization/utils/marlin_utils.py |   5 +
 9 files changed, 202 insertions(+), 135 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index fc259589d777d..91ef30b0874c3 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -9,7 +9,7 @@
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.quantization.utils import replace_tensor
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
     marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
@@ -240,7 +240,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
             num_bits=self.quant_config.quant_type.size_bits)
-        replace_tensor(layer, "qweight", marlin_qweight)
+        replace_parameter(layer, "qweight", marlin_qweight)

         # Permute scales from AWQ format to marlin format.
         marlin_scales = marlin_permute_scales(
@@ -248,7 +248,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
             group_size=self.quant_config.group_size)
-        replace_tensor(layer, "scales", marlin_scales)
+        replace_parameter(layer, "scales", marlin_scales)

         # Permute zero-points from AWQ format to marlin format.
         marlin_zp = awq_to_marlin_zero_points(
@@ -256,7 +256,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             size_k=layer.num_groups,
             size_n=layer.output_size_per_partition,
             num_bits=self.quant_config.quant_type.size_bits)
-        replace_tensor(layer, "qzeros", marlin_zp)
+        replace_parameter(layer, "qzeros", marlin_zp)

         # Not-used
         layer.g_idx = marlin_make_empty_g_idx(device)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index f086bce9666d3..50c7a1d3fbb20 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -8,6 +8,7 @@
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedvLLMParameter)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.kernels import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.scalar_type import scalar_types
@@ -68,7 +69,6 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
         )

         kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
-        self.kernel = kernel_type(mp_linear_kernel_config)

         # If group_size is -1, we are in channelwise case.
         channelwise = (self.group_size == -1)
@@ -125,9 +125,29 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)

+        self.kernel = kernel_type(mp_linear_kernel_config,
+                                  w_q_param_name="weight_packed",
+                                  w_s_param_name="weight_scale",
+                                  w_zp_param_name=None,
+                                  w_gidx_param_name=None)
+
     # Checkpoints are serialized in compressed-tensors format, which is
     # different from the format the kernel may want. Handle repacking here.
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # convert `weight_packed` from:
+        #     {input_dim = 1, output_dim = 0, packed_dim = 1}
+        # to:
+        #     {input_dim = 0, output_dim = 1, packed_dim = 0}
+        # as expected by the kernel's `process_weights_after_loading`
+        replace_parameter(layer, "weight_packed", layer.weight_packed.t())
+
+        # convert `weight_scale` from:
+        #     {input_dim = 1, output_dim = 0}
+        # to:
+        #     {input_dim = 0, output_dim = 1}
+        # as expected by the kernel's `process_weights_after_loading`
+        replace_parameter(layer, "weight_scale", layer.weight_scale.t())
+
         self.kernel.process_weights_after_loading(layer)
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 7f62e50e70ec9..fd64a4dc9cb89 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -5,11 +5,14 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+
+from vllm.model_executor.layers.quantization.kernels import (
+    MPLinearLayerConfig, choose_mp_linear_kernel)
+
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.quantization.utils import replace_tensor
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
     marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
@@ -160,22 +163,28 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
-        del output_size
         output_size_per_partition = sum(output_partition_sizes)
         is_row_parallel = input_size != input_size_per_partition

+        mp_linear_kernel_config = MPLinearLayerConfig(
+            full_weight_shape=(input_size, output_size),
+            partition_weight_shape=\
+                (input_size_per_partition, output_size_per_partition),
+            weight_type=self.quant_config.quant_type,
+            act_type=params_dtype,
+            group_size=self.quant_config.group_size,
+            zero_points=False,
+            act_reordering=self.quant_config.desc_act
+        )
+
+        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
+
         # Normalize group_size
         if self.quant_config.group_size != -1:
             group_size = self.quant_config.group_size
         else:
             group_size = input_size

-        verify_marlin_supports_shape(
-            output_size_per_partition=output_size_per_partition,
-            input_size_per_partition=input_size_per_partition,
-            input_size=input_size,
-            group_size=group_size)
-
         # Determine sharding
         if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
                                              self.quant_config.group_size,
@@ -269,11 +278,12 @@ def create_weights(
         layer.register_parameter("g_idx", g_idx)
         layer.register_parameter("scales", scales)
         layer.register_parameter("qzeros", qzeros)
-        layer.input_size_per_partition = input_size_per_partition
-        layer.output_size_per_partition = output_size_per_partition
-        layer.input_size = input_size
-        layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
-                                           is_row_parallel)
+
+        self.kernel = kernel_type(mp_linear_kernel_config,
+                                  w_q_param_name="qweight",
+                                  w_s_param_name="scales",
+                                  w_zp_param_name="qzeros",
+                                  w_gidx_param_name="g_idx")

     # Checkpoints are serialized in AutoGPTQ format, which is different from the
     # marlin format. This function is called after the weights are loaded.
@@ -281,39 +291,9 @@ def create_weights(
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         device = layer.qweight.device

-        # Allocate marlin workspace
-        layer.workspace = marlin_make_workspace(
-            layer.output_size_per_partition, device)
-
-        # Handle sorting for activation reordering if needed.
-        if self.quant_config.desc_act:
-            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
-            layer.g_idx_sort_indices = g_idx_sort_indices
-            replace_tensor(layer, "g_idx", g_idx)
-        else:
-            layer.g_idx = marlin_make_empty_g_idx(device)
-            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
-
-        # No zero-point
-        layer.zp = marlin_make_empty_g_idx(device)
-
-        # Repack weights from autogptq format to marlin format.
-        marlin_qweight = ops.gptq_marlin_repack(
-            layer.qweight,
-            perm=layer.g_idx_sort_indices,
-            size_k=layer.input_size_per_partition,
-            size_n=layer.output_size_per_partition,
-            num_bits=self.quant_config.quant_type.size_bits)
-        replace_tensor(layer, "qweight", marlin_qweight)
-
-        # Permute scales from autogptq format to marlin format.
-        marlin_scales = marlin_permute_scales(
-            layer.scales,
-            size_k=(layer.input_size if self.quant_config.desc_act else
-                    layer.input_size_per_partition),
-            size_n=layer.output_size_per_partition,
-            group_size=self.quant_config.group_size)
-        replace_tensor(layer, "scales", marlin_scales)
+        # `qweight` and `scales` are already in the correct format. So we can
+        # just call `process_weights_after_loading` right away.
+        self.kernel.process_weights_after_loading(layer)

     def apply(
         self,
@@ -321,16 +301,4 @@ def apply(
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        return apply_gptq_marlin_linear(
-            input=x,
-            weight=layer.qweight,
-            weight_scale=layer.scales,
-            weight_zp=layer.zp,
-            g_idx=layer.g_idx,
-            g_idx_sort_indices=layer.g_idx_sort_indices,
-            workspace=layer.workspace,
-            wtype=self.quant_config.quant_type,
-            output_size_per_partition=layer.output_size_per_partition,
-            input_size_per_partition=layer.input_size_per_partition,
-            is_k_full=layer.is_k_full,
-            bias=bias)
+        return self.kernel.apply_weights(layer, x, bias)
diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
index e1759e1429ef0..a8cf3c0d669d7 100644
--- a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
@@ -1,9 +1,10 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict, Callable

 import torch

+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.scalar_type import ScalarType

@@ -31,18 +32,48 @@ def can_implement(cls,
                       c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
         raise NotImplementedError

-    def __init__(self, c: MPLinearLayerConfig) -> None:
+    def __init__(self,
+                 c: MPLinearLayerConfig,
+                 w_q_param_name: str,
+                 w_s_param_name: str,
+                 w_zp_param_name: Optional[str] = None,
+                 w_gidx_param_name: Optional[str] = None) -> None:
         assert self.can_implement(c)
         self.config = c
+        self.w_q_name = w_q_param_name
+        self.w_s_name = w_s_param_name
+        self.w_zp_name = w_zp_param_name
+        self.w_gidx_name = w_gidx_param_name

     # note assumes that
-    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
-    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
+    #  `getattr(layer, w_q_name)` is:
+    #      {input_dim = 0, output_dim = 1, packed_dim = 0}
+    #  `getattr(layer, w_s_name)` is:
+    #      {input_dim = 0, output_dim = 1}
     @abstractmethod
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         raise NotImplementedError

     @abstractmethod
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         raise NotImplementedError
+
+    def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
+                         fn: Callable) -> None:
+        if name is not None and getattr(layer, name, None) is not None:
+            replace_parameter(layer, name, fn(getattr(layer, name)))
+
+    def _get_weight_params(
+            self, layer: torch.nn.Module
+    ) -> Tuple[torch.Tensor,  # w_q
+               torch.Tensor,  # w_s
+               Optional[torch.Tensor], Optional[torch.Tensor]]:
+        return (
+            getattr(layer, self.w_q_name),
+            getattr(layer, self.w_s_name),
+            getattr(layer, self.w_zp_name or "", None),
+            getattr(layer, self.w_gidx_name or "", None),
+        )
diff --git a/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py
index 11e79967206b8..d936549a912e6 100644
--- a/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py
@@ -1,5 +1,5 @@
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.utils import replace_tensor
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.machete_utils import (
     MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
     query_machete_supported_quant_types)
@@ -41,32 +41,29 @@ def can_implement(cls,
     # note assumes that
     #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
-    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
+    def process_weights_after_loading(self, layer: torch.nn.Module):
         # Repack weights and scales for Machete
-
-        replace_tensor(
-            layer, "weight_packed",
-            ops.machete_prepack_B(layer.weight_packed.t(),
-                                  self.config.weight_type))
-        replace_tensor(layer, "weight_scale",
-                       layer.weight_scale.clone().contiguous().t())
-
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        self._transform_param(
+            layer, self.w_q_name, lambda x: ops.machete_prepack_B(
+                x.t().contiguous().t(), self.config.weight_type))
+        self._transform_param(layer, self.w_s_name, lambda x: x.contiguous())
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         c = self.config
+        w_q, w_s, _, _ = self._get_weight_params(layer)

-        # print("running machete")
-        # print(layer.weight_packed.shape)
-        # print(layer.weight_scale.dtype, x.dtype, layer.group_size)
         x_2d = x.reshape(-1, x.shape[-1])
         out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

         output = ops.machete_gemm(a=x_2d,
-                                  b_q=layer.weight_packed,
+                                  b_q=w_q,
                                   b_type=c.weight_type,
                                   b_zeros=None,
-                                  b_scales=layer.weight_scale,
+                                  b_scales=w_s,
                                   b_group_size=c.group_size)

         if bias is not None:
diff --git a/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py
index 632785c10c31f..83e31f85c48b7 100644
--- a/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py
@@ -1,10 +1,9 @@
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.utils import replace_tensor
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
     check_marlin_supports_shape, marlin_make_empty_g_idx,
-    marlin_make_workspace, marlin_permute_scales,
-    query_marlin_supported_quant_types)
+    marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
+    marlin_is_k_full, query_marlin_supported_quant_types)

 from .MPLinearKernel import *

@@ -20,9 +19,8 @@ def can_implement(cls,
                       c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
         if c.zero_points:
             return False, "Zero points currently not supported by "\
-                          " Compressed Tensors + Marlin. (Kernel supports it"\
-                          " but CompressedTensorsWNA16 does not so support has"\
-                          " not been added to MarlinWNA16Kernel yet"
+                          " MarlinLinearKernel. Will be added when AWQMarlin "\
+                          "is migrated over to using MPLinearKernel backend"

         quant_types = query_marlin_supported_quant_types(c.zero_points)
         if c.weight_type not in quant_types:
@@ -43,51 +41,79 @@ def can_implement(cls,
     #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
     #  `weight_scale` is: {input_dim = 0, output_dim = 1}
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        device = layer.weight_packed.device
+        device = getattr(layer, self.w_q_name).device
         c = self.config

+        row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
+        self.is_k_full = marlin_is_k_full(c.act_reordering, row_parallel)
+
         # Allocate marlin workspace.
         self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
                                                device)

-        # Act-order not supported in compressed-tensors yet, so set to empty.
-        layer.g_idx = marlin_make_empty_g_idx(device)
-        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
-
-        # No zero-point
-        layer.weight_zp = marlin_make_empty_g_idx(device)
-
-        # Repack weights from compressed-tensors format to marlin format.
-        marlin_qweight = ops.gptq_marlin_repack(
-            layer.weight_packed.t().contiguous(),
-            perm=layer.g_idx_sort_indices,
-            size_k=c.partition_weight_shape[0],
-            size_n=c.partition_weight_shape[1],
-            num_bits=c.weight_type.size_bits)
-        replace_tensor(layer, "weight_packed", marlin_qweight)
-
-        # Permute scales from compressed-tensors format to marlin format.
-        marlin_scales = marlin_permute_scales(
-            layer.weight_scale.squeeze().t().contiguous(),
-            size_k=c.partition_weight_shape[0],
-            size_n=c.partition_weight_shape[1],
-            group_size=c.group_size)
-        replace_tensor(layer, "weight_scale", marlin_scales)
-
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        # Default names since marlin requires empty parameters for these,
+        # TODO: remove this requirement from marlin (allow optional tensors)
+        if self.w_gidx_name is None:
+            self.w_gidx_name = "g_idx"
+        if self.w_zp_name is None:
+            self.w_zp_name = "w_zp"
+
+        if c.act_reordering:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+                getattr(layer, self.w_gidx_name))
+            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+        else:
+            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+        if c.zero_points:
+            pass
+            # TODO (lucas): add the following when AWQMarlin is migrated over to
+            # using MPLinearKernel backend
+            # self._transform_param(layer, self.w_zp_name, lambda x: \
+            #     marlin_zero_points(
+            #         x,
+            #         size_k=c.partition_weight_shape[0],
+            #         size_n=c.partition_weight_shape[1],
+            #         num_bits=c.weight_type.size_bits))
+        else:
+            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
+
+        self._transform_param(layer, self.w_q_name, lambda x: \
+            ops.gptq_marlin_repack(
+                x.contiguous(),
+                perm=layer.g_idx_sort_indices,
+                size_k=c.partition_weight_shape[0],
+                size_n=c.partition_weight_shape[1],
+                num_bits=c.weight_type.size_bits))
+
+        self._transform_param(layer, self.w_s_name, lambda x: \
+            marlin_permute_scales(
+                x.contiguous(),
+                size_k=c.partition_weight_shape[0],
+                size_n=c.partition_weight_shape[1],
+                group_size=c.group_size))
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         c = self.config
+        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

+        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
+        # None for marlin
         return apply_gptq_marlin_linear(
             input=x,
-            weight=layer.weight_packed,
-            weight_scale=layer.weight_scale,
-            weight_zp=layer.weight_zp,
-            g_idx=layer.g_idx,
+            weight=w_q,
+            weight_scale=w_s,
+            weight_zp=w_zp,  # type: ignore
+            g_idx=w_gidx,  # type: ignore
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=self.workspace,
             wtype=c.weight_type,
             input_size_per_partition=c.partition_weight_shape[0],
             output_size_per_partition=c.partition_weight_shape[1],
-            is_k_full=True,
+            is_k_full=self.is_k_full,
             bias=bias)
diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py
index 3a88224040530..e60f0c79ac1f7 100644
--- a/vllm/model_executor/layers/quantization/utils/__init__.py
+++ b/vllm/model_executor/layers/quantization/utils/__init__.py
@@ -1,3 +1,3 @@
-from .layer_utils import replace_tensor
+from .layer_utils import replace_parameter, update_tensor_inplace

-__all__ = ['replace_tensor']
+__all__ = ['update_tensor_inplace', 'replace_parameter']
diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py
index 8345bd860ca29..e547a7577d43d 100644
--- a/vllm/model_executor/layers/quantization/utils/layer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py
@@ -1,12 +1,32 @@
+from typing import Union
 import torch


+def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor):
+    assert dst.dtype == src.dtype, "Tensors must have the same dtype"
+
+    # update tensor shape and stride
+    dst.as_strided_(src.shape, src.stride())
+
+    # If not the same underlying storage move tensor data
+    if dst.data_ptr() != src.data_ptr():
+        dst.copy_(src)
+        del src
+
+
 # Newly generated tensors need to replace existing tensors that are
 # already registered as parameters by vLLM (and won't be freed)
-def replace_tensor(layer: torch.nn.Module, name: str,
-                   new_t: torch.Tensor) -> None:
-    # It is important to use resize_() here since it ensures
-    # the same buffer is reused
-    getattr(layer, name).resize_(new_t.shape)
-    getattr(layer, name).copy_(new_t)
-    del new_t
+def replace_parameter(mod: torch.nn.Module, name: str,
+                      new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
+
+    old = getattr(mod, name)
+    if old.dtype == new.dtype and \
+            old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
+        # If possible, just update in-place to avoid re-registering the
+        # parameter; this can be faster if the underlying storage is the same
+        update_tensor_inplace(old, new)
+    else:
+        # Fallback: re-register the parameter
+        if not isinstance(new, torch.nn.Parameter):
+            new = torch.nn.Parameter(new)
+        mod.register_parameter(name, torch.nn.Parameter(new))
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index fb2fdb5ca28df..e83b4eacf8f38 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -159,6 +159,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
                               requires_grad=False)


+def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
+    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
+                              requires_grad=False)
+
+
 def marlin_sort_g_idx(
         g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
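
For context, a minimal sketch (not part of the patch) of how a quantization method is expected to drive the MPLinearKernel interface used above: build an MPLinearLayerConfig, pick a kernel with choose_mp_linear_kernel, tell it which registered parameters hold the quantized weight/scales/zero-points/g_idx, then repack after loading and call apply_weights at forward time. The config fields and parameter names mirror the GPTQMarlinLinearMethod changes in this patch; the helper name build_kernel_for_layer and the layer attributes it reads are illustrative assumptions only.

    # Illustrative sketch, assuming a layer object that already exposes the
    # size attributes and registered parameters created by create_weights().
    import torch

    from vllm.model_executor.layers.quantization.kernels import (
        MPLinearLayerConfig, choose_mp_linear_kernel)


    def build_kernel_for_layer(layer: torch.nn.Module, quant_type,
                               act_dtype: torch.dtype, group_size: int,
                               desc_act: bool):
        cfg = MPLinearLayerConfig(
            full_weight_shape=(layer.input_size, layer.output_size),
            partition_weight_shape=(layer.input_size_per_partition,
                                    layer.output_size_per_partition),
            weight_type=quant_type,
            act_type=act_dtype,
            group_size=group_size,
            zero_points=False,
            act_reordering=desc_act,
        )

        # Picks a kernel class (e.g. MacheteLinearKernel or MarlinLinearKernel)
        # whose can_implement() accepts this config.
        kernel_type = choose_mp_linear_kernel(cfg)

        # The kernel is told which registered parameters hold the quantized
        # weight, scales, zero-points and g_idx (names match create_weights()).
        return kernel_type(cfg,
                           w_q_param_name="qweight",
                           w_s_param_name="scales",
                           w_zp_param_name="qzeros",
                           w_gidx_param_name="g_idx")


    # After the checkpoint is loaded:
    #     kernel.process_weights_after_loading(layer)  # repack params in-place
    # At forward time:
    #     output = kernel.apply_weights(layer, x, bias)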