From fbb28b5764283b9e3da7dfec15f7cf0b0367bedd Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Mon, 25 Dec 2023 14:20:56 +0100
Subject: [PATCH] v1 - getting gibberish

---
 awq/models/__init__.py    |  1 +
 awq/models/auto.py        |  1 +
 awq/models/base.py        |  4 ++
 awq/models/whisper.py     | 77 +++++++++++++++++++++++++++++++++++++++
 awq/quantize/quantizer.py | 11 +++++-
 examples/whisper_quant.py | 24 ++++++++++++
 6 files changed, 116 insertions(+), 2 deletions(-)
 create mode 100644 awq/models/whisper.py
 create mode 100644 examples/whisper_quant.py

diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index cdcc4724..bb338628 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -13,3 +13,4 @@
 from .baichuan import BaichuanAWQForCausalLM
 from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
+from .whisper import WhisperAWQForConditionalGeneration
\ No newline at end of file
diff --git a/awq/models/auto.py b/awq/models/auto.py
index 3db64461..51ad2c6f 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -21,6 +21,7 @@
     "qwen": QwenAWQForCausalLM,
     "baichuan": BaichuanAWQForCausalLM,
     "llava": LlavaAWQForCausalLM,
+    "whisper": WhisperAWQForConditionalGeneration
 }
 
 def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):
diff --git a/awq/models/base.py b/awq/models/base.py
index 756e8a2f..503f122c 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -57,9 +57,12 @@
     "qwen": "AutoModelForCausalLM",
     "baichuan": "AutoModelForCausalLM",
     "llava": "AutoModelForVision2Seq",
+    "whisper": "AutoModelForSpeechSeq2Seq"
 }
 
 class BaseAWQForCausalLM(nn.Module):
+    is_encoder_decoder = False
+
     def __init__(self, model, model_type, is_quantized, config, quant_config, processor):
         super().__init__()
         self.model:PreTrainedModel = model
@@ -69,6 +72,7 @@ def __init__(self, model, model_type, is_quantized, config, quant_config, proces
         self.config: PretrainedConfig = config
         self.quant_config: AwqConfig = quant_config
         self.processor: CLIPImageProcessor = processor
+

     def to(self, device: str):
         return self.model.to(device)
diff --git a/awq/models/whisper.py b/awq/models/whisper.py
new file mode 100644
index 00000000..d6f37b7a
--- /dev/null
+++ b/awq/models/whisper.py
@@ -0,0 +1,77 @@
+import tqdm
+from typing import List, Tuple
+from .base import BaseAWQForCausalLM
+from awq.utils.fused_utils import fuse_qkv
+from awq.modules.fused.block import LlamaLikeBlock
+from awq.modules.fused.model import LlamaLikeModel
+from transformers.models.whisper.modeling_whisper import (
+    WhisperDecoderLayer as OldWhisperDecoderLayer,
+)
+from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration as OldWhisperForConditionalGeneration
+from awq.modules.fused.mlp import QuantFusedMLP
+from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+class WhisperAWQForConditionalGeneration(BaseAWQForCausalLM):
+    layer_type = "WhisperDecoderLayer"
+    max_new_tokens_key = "max_position_embeddings"
+    is_encoder_decoder = True
+
+    def get_input_embeds(self):
+        return self.model.model.decoder.embed_tokens
+
+    @staticmethod
+    def fuse_layers(model: OldWhisperForConditionalGeneration):
+        # Layer fusion is not implemented for Whisper yet in this v1.
+        raise NotImplementedError("fuse_layers is not supported for Whisper")
+
+    @staticmethod
+    def get_model_layers(model: OldWhisperForConditionalGeneration):
+        return model.model.decoder.layers
+
+    @staticmethod
+    def get_act_for_scaling(module: OldWhisperDecoderLayer):
+        return dict(
+            is_scalable=False
+        )
+
+    @staticmethod
+    def move_embed(model: OldWhisperForConditionalGeneration, device: str):
+        model.proj_out = model.get_output_embeddings().to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(module: OldWhisperDecoderLayer, input_feat, module_kwargs):
+        layers = []
+
+        # attention input
+        layers.append(dict(
+            prev_op=module.self_attn_layer_norm,
+            layers=[module.self_attn.q_proj,
+                    module.self_attn.k_proj, module.self_attn.v_proj],
+            inp=input_feat['self_attn.q_proj'],
+            module2inspect=module.self_attn, kwargs=module_kwargs,
+        ))
+
+        # attention out
+        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
+        if module.self_attn.v_proj.weight.shape == module.self_attn.out_proj.weight.shape:
+            layers.append(dict(
+                prev_op=module.self_attn.v_proj,
+                layers=[module.self_attn.out_proj],
+                inp=input_feat['self_attn.out_proj'],
+            ))
+
+        # linear 1
+        layers.append(dict(
+            prev_op=module.final_layer_norm,
+            layers=[module.fc1],
+            inp=input_feat['fc1'],
+        ))
+
+        layers.append(dict(
+            prev_op=module.fc1,
+            layers=[module.fc2],
+            inp=input_feat['fc2'],
+        ))
+
+        return layers
+
diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py
index feb563eb..2ea07956 100644
--- a/awq/quantize/quantizer.py
+++ b/awq/quantize/quantizer.py
@@ -352,8 +352,15 @@ def forward(self, *args, **kwargs):
 
         # patch layer 0 to catch input and kwargs
         modules[0] = Catcher(modules[0])
+
+
         try:
-            self.model(samples.to(next(self.model.parameters()).device))
+            if not self.awq_model.is_encoder_decoder:
+                device = next(self.awq_model.parameters()).device
+                self.model(samples.to(device))
+            else:
+                device = self.awq_model.get_input_embeds().weight.device
+                modules[0](self.awq_model.get_input_embeds()(samples.to(device)))
         except ValueError:  # work with early exit
             pass
 
@@ -361,7 +368,7 @@ def forward(self, *args, **kwargs):
         # that takes care of everything to avoid unexpected errors.
         layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
         # Pop the input_ids as they are not needed at all.
-        layer_kwargs.pop("input_ids")
+        layer_kwargs.pop("input_ids", None)
 
         del samples
         modules[0] = modules[0].module  # restore
diff --git a/examples/whisper_quant.py b/examples/whisper_quant.py
new file mode 100644
index 00000000..82f4a6a4
--- /dev/null
+++ b/examples/whisper_quant.py
@@ -0,0 +1,24 @@
+from awq import AutoAWQForCausalLM
+from transformers import AutoTokenizer
+
+model_path = 'openai/whisper-large-v3'
+quant_path = 'whisper-large-awq'
+modules_to_not_convert = ["encoder_attn", "encoder"]
+
+quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
+
+# Load model
+# NOTE: pass safetensors=True to load safetensors
+model = AutoAWQForCausalLM.from_pretrained(
+    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
+)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+# Quantize
+model.quantize(tokenizer, quant_config=quant_config, modules_to_not_convert=modules_to_not_convert)
+
+# Save quantized model
+model.save_quantized(quant_path)
+tokenizer.save_pretrained(quant_path)
+
+print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file
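
A minimal follow-up sketch (not part of this patch) of how the checkpoint produced by examples/whisper_quant.py could be exercised. It assumes the checkpoint loads back through AutoAWQForCausalLM.from_quantized and that the wrapper's generate() forwards keyword arguments to the underlying WhisperForConditionalGeneration; fuse_layers=False is passed because layer fusion is not implemented for Whisper here. As the subject line says, this v1 still produces gibberish, so the snippet only exercises the decode path.

import numpy as np
from awq import AutoAWQForCausalLM
from transformers import WhisperProcessor

quant_path = "whisper-large-awq"

# Processor handles feature extraction and decoding; reuses the original model id.
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
# fuse_layers=False since fuse_layers() is not implemented for Whisper in this v1.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)

# One second of 16 kHz silence stands in for real audio in this sketch.
audio = np.zeros(16_000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

# Match the device and dtype of the (partly half-precision) quantized model.
param = next(model.model.parameters())
input_features = inputs.input_features.to(device=param.device, dtype=param.dtype)

generated_ids = model.generate(input_features=input_features, max_new_tokens=64)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))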