Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Whiper] Add Whisper to AutoAWQ #282

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions awq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
from .baichuan import BaichuanAWQForCausalLM
from .llava import LlavaAWQForCausalLM
from .mixtral import MixtralAWQForCausalLM
from .whisper import WhisperAWQForConditionalGeneration
1 change: 1 addition & 0 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"qwen": QwenAWQForCausalLM,
"baichuan": BaichuanAWQForCausalLM,
"llava": LlavaAWQForCausalLM,
"whisper": WhisperAWQForConditionalGeneration
}

def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):
Expand Down
4 changes: 4 additions & 0 deletions awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@
"qwen": "AutoModelForCausalLM",
"baichuan": "AutoModelForCausalLM",
"llava": "AutoModelForVision2Seq",
"whisper": "AutoModelForSpeechSeq2Seq"
}

class BaseAWQForCausalLM(nn.Module):
is_encoder_decoder = False

def __init__(self, model, model_type, is_quantized, config, quant_config, processor):
super().__init__()
self.model:PreTrainedModel = model
Expand All @@ -69,6 +72,7 @@ def __init__(self, model, model_type, is_quantized, config, quant_config, proces
self.config: PretrainedConfig = config
self.quant_config: AwqConfig = quant_config
self.processor: CLIPImageProcessor = processor


def to(self, device: str):
return self.model.to(device)
Expand Down
77 changes: 77 additions & 0 deletions awq/models/whisper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.whisper.modeling_whisper import (
WhisperDecoderLayer as OldWhisperDecoderLayer,
)
from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration as OldWhisperForConditionalGeneration
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class WhisperAWQForConditionalGeneration(BaseAWQForCausalLM):
layer_type = "WhisperDecoderLayer"
max_new_tokens_key = "max_position_embeddings"
is_encoder_decoder = True

def get_input_embeds(self):
return self.model.model.decoder.embed_tokens

@staticmethod
def fuse_layers(model: OldWhisperForConditionalGeneration):
fuser = LlavaFuser(model)
fuser.fuse_transformer()

@staticmethod
def get_model_layers(model: OldWhisperForConditionalGeneration):
return model.model.decoder.layers

@staticmethod
def get_act_for_scaling(module: OldWhisperDecoderLayer):
return dict(
is_scalable=False
)

@staticmethod
def move_embed(model: OldWhisperForConditionalGeneration, device: str):
model.proj_out = model.get_output_embeddings().to(device)

@staticmethod
def get_layers_for_scaling(module: OldWhisperDecoderLayer, input_feat, module_kwargs):
layers = []

# attention input
layers.append(dict(
prev_op=module.self_attn_layer_norm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))

# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.out_proj.weight.shape:
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.out_proj],
inp=input_feat['self_attn.out_proj'],
))

# linear 1
layers.append(dict(
prev_op=module.final_layer_norm,
layers=[module.fc1],
inp=input_feat['fc1'],
))

layers.append(dict(
prev_op=module.fc1,
layers=[module.fc2],
inp=input_feat['fc2'],
))

return layers

11 changes: 9 additions & 2 deletions awq/quantize/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,16 +352,23 @@ def forward(self, *args, **kwargs):

# patch layer 0 to catch input and kwargs
modules[0] = Catcher(modules[0])


try:
self.model(samples.to(next(self.model.parameters()).device))
if not self.awq_model.is_encoder_decoder:
device = next(self.awq_model.parameters()).device
self.model(samples.to(device))
else:
device = self.awq_model.get_input_embeds().weight.device
modules[0](self.awq_model.get_input_embeds()(samples.to(device)))
except ValueError: # work with early exit
pass

# Update the layer kwargs with `prepare_inputs_for_generation` method
# that takes care of everything to avoid unexpected errors.
layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
# Pop the input_ids as they are not needed at all.
layer_kwargs.pop("input_ids")
layer_kwargs.pop("input_ids", None)

del samples
modules[0] = modules[0].module # restore
Expand Down
24 changes: 24 additions & 0 deletions examples/whisper_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'openai/whisper-large-v3'
quant_path = 'whisper-large-awq'
modules_to_not_convert = ["encoder_attn", "encoder"]

quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config, modules_to_not_convert=modules_to_not_convert)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')