Skip to content

Commit

Permalink
Llama with inputs_embeds only(LLava-v1.5 bug fixed) and Llava-v1.6 Su…
Browse files Browse the repository at this point in the history
…pport (#471)

Co-authored-by: Casper <[email protected]>
  • Loading branch information
WanBenLe and casper-hansen authored Jul 23, 2024
1 parent 6919d7b commit b55f73c
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 17 deletions.
1 change: 1 addition & 0 deletions awq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .gemma import GemmaAWQForCausalLM
from .stablelm import StableLmAWQForCausalLM
from .starcoder2 import Starcoder2AWQForCausalLM
from .llava_next import LlavaNextAWQForCausalLM
from .phi3 import Phi3AWQForCausalLM
from .cohere import CohereAWQForCausalLM
from .deepseek_v2 import DeepseekV2AWQForCausalLM
Expand Down
2 changes: 2 additions & 0 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from awq.models import *
from awq.models.base import BaseAWQForCausalLM


AWQ_CAUSAL_LM_MODEL_MAP = {
"mpt": MptAWQForCausalLM,
"llama": LlamaAWQForCausalLM,
Expand All @@ -26,6 +27,7 @@
"gemma": GemmaAWQForCausalLM,
"stablelm": StableLmAWQForCausalLM,
"starcoder2": Starcoder2AWQForCausalLM,
"llava_next": LlavaNextAWQForCausalLM,
"phi3": Phi3AWQForCausalLM,
"cohere": CohereAWQForCausalLM,
"deepseek_v2": DeepseekV2AWQForCausalLM,
Expand Down
1 change: 1 addition & 0 deletions awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"gemma": "AutoModelForCausalLM",
"stablelm": "AutoModelForCausalLM",
"starcoder2": "AutoModelForCausalLM",
"llava_next": "AutoModelForVision2Seq",
"phi3": "AutoModelForCausalLM",
"cohere": "AutoModelForCausalLM",
"deepseek_v2": "AutoModelForCausalLM",
Expand Down
6 changes: 5 additions & 1 deletion awq/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ def fuse_transformer(self):
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon,
)
if hasattr(self.model.config, "max_seq_len"):
max_seq_len = self.model.config.max_seq_len
else:
max_seq_len = self.model.config.max_position_embeddings
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
Expand All @@ -128,7 +132,7 @@ def fuse_transformer(self):
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
max_seq_len=max_seq_len,
)
)

Expand Down
145 changes: 145 additions & 0 deletions awq/models/llava_next.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldLlamaDecoderLayer,
)
from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration
from awq.modules.fused.norm import FasterTransformerRMSNorm


class LlavaNextAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "LlamaDecoderLayer"
max_seq_len_key = "max_position_embeddings"

@staticmethod
def fuse_layers(model: LlavaNextForConditionalGeneration):
pass

@staticmethod
def get_model_layers(model: LlavaNextForConditionalGeneration):
return model.language_model.model.layers

@staticmethod
def get_act_for_scaling(module: OldLlamaDecoderLayer):
return dict(is_scalable=False)

@staticmethod
def move_embed(model: LlavaNextForConditionalGeneration, device: str):
model.language_model.model.embed_tokens = model.get_input_embeddings().to(
device
)

@staticmethod
def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
layers = []

# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)

# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)

# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)

# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)

return layers


class LlavaNextFuser:
def __init__(self, model: LlavaNextForConditionalGeneration):
self.model = model.language_model

self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower()
]

def fuse_transformer(self):
blocks = []

module: OldLlamaDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon,
)
if hasattr(self.model.config, "max_seq_len"):
max_seq_len = self.model.config.max_seq_len
else:
max_seq_len = self.model.config.max_position_embeddings
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=max_seq_len,
)
)

self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)



2 changes: 1 addition & 1 deletion awq/modules/fused/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,4 @@ def forward(
past_key_values=None,
hidden_states=(),
attentions=(),
)
)
64 changes: 49 additions & 15 deletions docs/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,34 @@ tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
```

### Vision-Language Models

AutoAWQ supports a few vision-language models. So far, we support LLaVa 1.5 and LLaVa 1.6 (next).

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'llava-hf/llama3-llava-next-8b-hf'
quant_path = 'llama3-llava-next-8b-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, device_map="cuda", **{"low_cpu_mem_usage": True}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
```

### GGUF Export

This computes AWQ scales and appliesthem to the model without running real quantization.
Expand Down Expand Up @@ -405,29 +433,35 @@ AutoAWQ also supports the LLaVa model. You simply need to load an
AutoProcessor to process the prompt and image to generate inputs for the AWQ model.

```python
import requests
import torch
import requests
from PIL import Image

from awq import AutoAWQForCausalLM
from transformers import AutoProcessor

quant_path = "ybelkada/llava-1.5-7b-hf-awq"
from transformers import AutoProcessor, TextStreamer

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, safetensors=True, device_map={"": 0})
quant_path = "casperhansen/llama3-llava-next-8b-awq"
model = AutoAWQForCausalLM.from_quantized(quant_path)
processor = AutoProcessor.from_pretrained(quant_path)
streamer = TextStreamer(processor, skip_prompt=True)

prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Define prompt
prompt = """\
<|im_start|>system\nAnswer the questions.<|im_end|>
<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>
<|im_start|>assistant
"""

# Define image
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)

# Load inputs
inputs = processor(prompt, image, return_tensors='pt').to(0, torch.float16)

raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
# Generate output
generation_output = model.generate(
**inputs,
max_new_tokens=512
**inputs,
max_new_tokens=512,
streamer=streamer
)

print(processor.decode(generation_output[0], skip_special_tokens=True))
```

0 comments on commit b55f73c

Please sign in to comment.