Update model_base.py (#124)
fix quantization config for layers
YangWang92 authored Nov 18, 2024
1 parent 4adafff commit 139a380
Showing 1 changed file with 2 additions and 1 deletion.
vptq/layers/model_base.py: 3 changes (2 additions & 1 deletion)
@@ -91,10 +91,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

         target_layer = VQuantLinear
         quantization_config = auto_conf.quantization_config
+        config_for_layers = quantization_config['config_for_layers']

         # replace linear layers with quantized linear layers
         with transformers.utils.generic.ContextManagers([accelerate.init_empty_weights()]):
-            make_quant_linear(model, quantization_config, target_layer=target_layer)
+            make_quant_linear(model, config_for_layers, target_layer=target_layer)

         no_split_module_classes = [i[1].__class__.__name__ for i in model.named_modules() if i[0].endswith(".0")]

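For context, a minimal sketch of the shape this change assumes: the model's quantization_config nests per-layer settings under a 'config_for_layers' key, and make_quant_linear expects that inner per-layer mapping rather than the whole top-level dict. The layer names and fields below are illustrative placeholders, not the repository's actual schema.

# Hypothetical sketch (not from the repository) of the assumed config shape.
# quantization_config, as loaded from the model's config, nests the
# per-layer settings under 'config_for_layers'; make_quant_linear needs
# that inner mapping, not the whole top-level dict.
quantization_config = {
    "quant_method": "vptq",  # top-level metadata; key name is illustrative
    "config_for_layers": {
        # one entry per linear module to be replaced with VQuantLinear;
        # the inner fields here are placeholders, not the real schema
        "model.layers.0.self_attn.q_proj": {"example_field": 1},
        "model.layers.0.mlp.gate_proj": {"example_field": 1},
    },
}

# Before this commit the whole top-level dict was passed, so a per-layer
# lookup could hit metadata keys such as "quant_method" instead of layer
# names. After the fix only the per-layer mapping is handed over:
config_for_layers = quantization_config["config_for_layers"]
for layer_name, layer_cfg in config_for_layers.items():
    print(layer_name, layer_cfg)  # stand-in for the real per-layer replacement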