From 0d40b99f0c3441fc45f756a1d84dcec9ce6cbf83 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 18 Jun 2024 18:53:01 -0400 Subject: [PATCH] Support calibrating kv cache scales (#17) * Support calibrating kv cache scales * Add comment * Fix weight name * Add Qwen test * Fix kv cache test count * Add fixed target sizes * Fix proj linear count * Switch from output_scale to kv_scale * Add example --- auto_fp8/config.py | 9 +- auto_fp8/modeling.py | 49 ++++++--- auto_fp8/quantize.py | 156 ++++++++++++++++++++--------- examples/example_static_kvcache.py | 25 +++++ tests/test_auto_fp8.py | 79 +++++++++++++-- 5 files changed, 249 insertions(+), 69 deletions(-) create mode 100644 examples/example_static_kvcache.py diff --git a/auto_fp8/config.py b/auto_fp8/config.py index 7f8dd95..24c6200 100644 --- a/auto_fp8/config.py +++ b/auto_fp8/config.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional, Tuple class BaseQuantizeConfig: @@ -17,13 +17,17 @@ class BaseQuantizeConfig: regex style matching i.e. re.search(), for each Linear layer. By default, "re:.*lm_head" is included to ignore the embedding Linear layer usually at the end of decoder LLMs + kv_cache_quant_targets: Tuple of Linear module names to target for + calibration of the output scales for KV cache quantization. + Usually, these should be `("k_proj", "v_proj")`. """ def __init__( self, quant_method: str = "fp8", activation_scheme: str = "static", - ignore_patterns: List[str] = [], + ignore_patterns: List[str] = ["re:.*lm_head"], + kv_cache_quant_targets: Optional[Tuple[str]] = None, ): if quant_method != "fp8": raise ValueError("Only FP8 quantization is supported.") @@ -34,4 +38,5 @@ def __init__( self.quant_method = quant_method self.activation_scheme = activation_scheme self.ignore_patterns = ignore_patterns + self.kv_cache_quant_targets = kv_cache_quant_targets self.ignored_layers = [] diff --git a/auto_fp8/modeling.py b/auto_fp8/modeling.py index 340a598..04a9e71 100644 --- a/auto_fp8/modeling.py +++ b/auto_fp8/modeling.py @@ -1,5 +1,5 @@ import re -from typing import List +from typing import List, Optional, Tuple import torch from transformers import AutoModelForCausalLM @@ -27,6 +27,16 @@ def __init__( self.model, quantize_config.ignore_patterns ) + if quantize_config.kv_cache_quant_targets: + kv_cache_quant_layers = get_kv_cache_quant_layers( + self.model, quantize_config.kv_cache_quant_targets + ) + if len(kv_cache_quant_layers) == 0: + raise ValueError( + f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument." 
+ ) + quantize_config.kv_cache_quant_layers = kv_cache_quant_layers + self.quantize_config = quantize_config @classmethod @@ -97,26 +107,28 @@ def skip(*args, **kwargs): return cls(model, quantize_config) - def quantize(self, calibration_tokens): - def _prepare_calibration_data(calibration_tokens): - if hasattr(calibration_tokens, "input_ids"): - return calibration_tokens.input_ids - return calibration_tokens + def quantize(self, calibration_tokens: Optional[torch.Tensor] = None): # Always quantize the weights as they do not require calibration data quantize_weights(self.model, self.quantize_config) if self.quantize_config.activation_scheme == "static": + assert ( + calibration_tokens is not None + ), "Calibration tokens required for activation quantization" + + + def _prepare_calibration_data(calibration_tokens): + if hasattr(calibration_tokens, "input_ids"): + return calibration_tokens.input_ids + return calibration_tokens + quantize_activations( self.model, self.quantize_config, _prepare_calibration_data(calibration_tokens), ) - # import copy - # for layer in self.model.model.layers: - # layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.input_scale) - def save_quantized(self, save_dir): save_quantized_model( self.model, @@ -128,9 +140,6 @@ def save_quantized(self, save_dir): def get_layers_to_ignore(model, ignore_patterns) -> List[str]: ignored_layers = set() - # TODO: don't always ignore lm_head - ignore_patterns.append("re:.*lm_head") - for name, linear in model.named_modules(): if not isinstance(linear, torch.nn.Linear): continue @@ -148,3 +157,17 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]: ignored_layers.add(name) return list(ignored_layers) + + +def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]: + kv_cache_quant_layers = [] + + for name, linear in model.named_modules(): + if not isinstance(linear, torch.nn.Linear): + continue + + for output_quant_target in kv_cache_quant_targets: + if name.endswith(output_quant_target): + kv_cache_quant_layers.append(name) + + return kv_cache_quant_layers diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py index 4c1b580..38a4de6 100644 --- a/auto_fp8/quantize.py +++ b/auto_fp8/quantize.py @@ -1,6 +1,6 @@ import gc import re -from typing import List, Tuple +from typing import Optional, Tuple import copy import torch @@ -61,14 +61,22 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]: return qweight, scale +def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) + return qweight.to(torch.float8_e4m3fn) + + def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): if A.numel() == 0: # Deal with empty tensors (triggeted by empty MoE experts) return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device) - - native_fp8_support = ( - torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) - ) + + # TODO: Disable native fp8 gemm for now, always just dequantize + # native_fp8_support = ( + # torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) + # ) + native_fp8_support = False if native_fp8_support: need_reshape = A.dim() == 3 if need_reshape: @@ -98,25 +106,24 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): return output -class FP8StaticLinearQuantizer(torch.nn.Module): +# Class responsible for quantizing weights +class 
FP8DynamicLinear(torch.nn.Module): def __init__( - self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor + self, + weight: torch.Tensor, + weight_scale: torch.Tensor, + bias: torch.nn.Parameter, ): super().__init__() - self.weight = torch.nn.Parameter(qweight, requires_grad=False) + self.weight = torch.nn.Parameter(weight, requires_grad=False) self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - self.input_scale = None self.bias = bias def forward(self, x): - qinput, x_input_scale = per_tensor_quantize(x) - if self.input_scale is None: - self.input_scale = torch.nn.Parameter(x_input_scale) - elif x_input_scale > self.input_scale: - self.input_scale = torch.nn.Parameter(x_input_scale) + qinput, x_scale = per_tensor_quantize(x) output = fp8_gemm( A=qinput, - A_scale=self.input_scale, + A_scale=x_scale, B=self.weight, B_scale=self.weight_scale, bias=self.bias, @@ -125,29 +132,29 @@ def forward(self, x): return output -class FP8StaticLinear(torch.nn.Module): +# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer +class FP8StaticLinearQuantizer(torch.nn.Module): def __init__( self, - qweight: torch.Tensor, + weight: torch.Tensor, weight_scale: torch.Tensor, - bias: torch.Tensor, - input_scale: float = 1.0, + bias: torch.nn.Parameter, + quantize_output: bool = False, ): super().__init__() - self.weight = torch.nn.Parameter(qweight, requires_grad=False) + self.weight = torch.nn.Parameter(weight, requires_grad=False) self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False) self.bias = bias - - def per_tensor_quantize( - self, tensor: torch.Tensor, inv_scale: float - ) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) - return qweight.to(torch.float8_e4m3fn) + self.input_scale = None + self.output_scale = None + self.quantize_output = quantize_output def forward(self, x): - qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale) + qinput, x_input_scale = per_tensor_quantize(x) + if self.input_scale is None: + self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False) + elif x_input_scale > self.input_scale: + self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False) output = fp8_gemm( A=qinput, A_scale=self.input_scale, @@ -156,26 +163,51 @@ def forward(self, x): bias=self.bias, out_dtype=x.dtype, ) + + # Optionally, quantize output and record scale + if self.quantize_output: + qoutput, output_scale = per_tensor_quantize(output) + if self.output_scale is None: + self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False) + elif output_scale > self.output_scale: + self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False) + output = qoutput.to(output.dtype) * output_scale + return output -class FP8DynamicLinear(torch.nn.Module): - def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor): +# Module responsible for representing the final checkpoint representation +class FP8StaticLinear(torch.nn.Module): + def __init__( + self, + weight: torch.nn.Parameter, + weight_scale: torch.nn.Parameter, + bias: torch.nn.Parameter, + input_scale: torch.nn.Parameter, + output_scale: Optional[torch.nn.Parameter] = None, + ): super().__init__() - self.weight = torch.nn.Parameter(qweight, requires_grad=False) - 
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False) + self.weight = weight + self.weight_scale = weight_scale self.bias = bias + self.input_scale = input_scale + self.output_scale = output_scale def forward(self, x): - qinput, x_scale = per_tensor_quantize(x) + qinput = static_per_tensor_quantize(x, self.input_scale) output = fp8_gemm( A=qinput, - A_scale=x_scale, + A_scale=self.input_scale, B=self.weight, B_scale=self.weight_scale, bias=self.bias, out_dtype=x.dtype, ) + + if self.output_scale: + qoutput = static_per_tensor_quantize(output, self.output_scale) + output = qoutput.to(output.dtype) * self.output_scale + return output @@ -194,7 +226,6 @@ def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn. def quantize_weights( model: AutoModelForCausalLM, quantize_config: BaseQuantizeConfig, - ignored_layers: List[str] = [], ): named_modules = list(model.named_modules()) for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"): @@ -203,9 +234,11 @@ def quantize_weights( or name in quantize_config.ignored_layers ): continue - quant_weight, quant_scale = per_tensor_quantize(linear.weight) + quant_weight, weight_scale = per_tensor_quantize(linear.weight) bias = copy.deepcopy(linear.bias) if linear.bias is not None else None - quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias) + quant_linear = FP8DynamicLinear( + weight=quant_weight, weight_scale=weight_scale, bias=bias + ) replace_module(model, name, quant_linear) del linear.weight del linear.bias @@ -217,7 +250,6 @@ def quantize_activations( model: AutoModelForCausalLM, quantize_config: BaseQuantizeConfig, calibration_tokens, - ignored_layers: List[str] = [], ): # Replace weight quantizer with a dynamic activation quantizer observer for name, dynamic_quant_linear in model.named_modules(): @@ -227,9 +259,13 @@ def quantize_activations( ): continue quantizer = FP8StaticLinearQuantizer( - dynamic_quant_linear.weight, - dynamic_quant_linear.weight_scale, - dynamic_quant_linear.bias, + weight=dynamic_quant_linear.weight, + weight_scale=dynamic_quant_linear.weight_scale, + bias=dynamic_quant_linear.bias, + quantize_output=( + hasattr(quantize_config, "kv_cache_quant_layers") + and name in quantize_config.kv_cache_quant_layers + ), ) replace_module(model, name, quantizer) del dynamic_quant_linear @@ -251,21 +287,45 @@ def quantize_activations( ): continue static_proj = FP8StaticLinear( - quantizer.weight, - quantizer.weight_scale, - quantizer.bias, - quantizer.input_scale, + weight=quantizer.weight, + weight_scale=quantizer.weight_scale, + bias=quantizer.bias, + input_scale=quantizer.input_scale, + output_scale=quantizer.output_scale, ) replace_module(model, name, static_proj) del quantizer cleanup_memory() + # Post-process step for kv cache scales to take the k/v module + # `output_scale` parameters, take the max of them, and store them in + # the parent attention module as `kv_scale` + # NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block + if hasattr(quantize_config, "kv_cache_quant_layers"): + # Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...] + # so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...] 
+ kv_proj_pairs = zip(*[iter(quantize_config.kv_cache_quant_layers)]*2) + for k_proj_name, v_proj_name in kv_proj_pairs: + parent_module_name = ".".join(k_proj_name.split(".")[:-1]) + assert parent_module_name == ".".join(v_proj_name.split(".")[:-1]) + parent_module = dict(model.named_modules())[parent_module_name] + + k_proj = dict(model.named_modules())[k_proj_name] + v_proj = dict(model.named_modules())[v_proj_name] + + kv_scale = max(k_proj.output_scale, v_proj.output_scale) + parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False) + + # Remove output_scale from k_proj and v_proj + k_proj.output_scale = None + v_proj.output_scale = None + cleanup_memory() + def save_quantized_model( model: AutoModelForCausalLM, quant_config: BaseQuantizeConfig, save_dir: str, - ignored_layers: List[str] = [], ): print(model) print(f"Saving the model to {save_dir}") @@ -276,6 +336,8 @@ def save_quantized_model( "ignored_layers": quant_config.ignored_layers, } } + if hasattr(quant_config, "kv_cache_quant_layers"): + static_q_dict["quantization_config"]["kv_cache_scheme"] = "static" model.config.update(static_q_dict) model.save_pretrained(save_dir) tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) diff --git a/examples/example_static_kvcache.py b/examples/example_static_kvcache.py new file mode 100644 index 0000000..118bad5 --- /dev/null +++ b/examples/example_static_kvcache.py @@ -0,0 +1,25 @@ +from datasets import load_dataset +from transformers import AutoTokenizer + +from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig + +pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" +quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8-KV" + +tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) +tokenizer.pad_token = tokenizer.eos_token + +ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512)) +examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds] +examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda") + +quantize_config = BaseQuantizeConfig( + quant_method="fp8", + activation_scheme="static", + ignore_patterns=["re:.*lm_head"], + kv_cache_quant_targets=("k_proj", "v_proj"), +) + +model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) +model.quantize(examples) +model.save_quantized(quantized_model_dir) diff --git a/tests/test_auto_fp8.py b/tests/test_auto_fp8.py index 51db3c1..6045d84 100644 --- a/tests/test_auto_fp8.py +++ b/tests/test_auto_fp8.py @@ -1,14 +1,43 @@ import os import shutil +import pytest +import safetensors.torch from transformers import AutoTokenizer from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig +MODELS = [ + ("facebook/opt-125m", 160), + ("Qwen/Qwen2-0.5B-Instruct", 620), +] -def test_quantization(): - model_id = "facebook/opt-125m" - quantized_model_dir = "opt-125m-fp8" +@pytest.mark.parametrize("model_id,target_size", MODELS) +def test_dynamic_quantization(model_id, target_size): + quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic" + + quantize_config = BaseQuantizeConfig( + quant_method="fp8", activation_scheme="dynamic" + ) + + model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) + model.model.to("cpu") + + model.quantize() + model.save_quantized(quantized_model_dir) + + # Measure checkpoint size and cleanup + model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors") + shutil.rmtree(quantized_model_dir) + + # We expect the 
quantized model to be a certain size + target_size = target_size * (1024 * 1024) + assert model_size < target_size + + +@pytest.mark.parametrize("model_id,target_size", MODELS) +def test_static_quantization(model_id, target_size): + quantized_model_dir = model_id.split("/")[-1] + "-fp8-static" tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) examples = ["auto-fp8 is an easy-to-use model quantization library"] @@ -16,18 +45,54 @@ def test_quantization(): quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static") - model = AutoFP8ForCausalLM.from_pretrained( - model_id, quantize_config=quantize_config + model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) + model.model.to("cpu") + + model.quantize(examples) + model.save_quantized(quantized_model_dir) + + # Measure checkpoint size and cleanup + model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors") + shutil.rmtree(quantized_model_dir) + + # We expect the quantized model to be a certain size + target_size = target_size * (1024 * 1024) + assert model_size < target_size + +@pytest.mark.parametrize("model_id,target_size", MODELS) +def test_kv_cache_static_quantization(model_id, target_size): + quantized_model_dir = model_id.split("/")[-1] + "-fp8-static-kv" + + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) + examples = ["auto-fp8 is an easy-to-use model quantization library"] + examples = tokenizer(examples, return_tensors="pt") + + quantize_config = BaseQuantizeConfig( + quant_method="fp8", + activation_scheme="static", + kv_cache_quant_targets=("k_proj", "v_proj"), ) + + model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config) model.model.to("cpu") model.quantize(examples) model.save_quantized(quantized_model_dir) + tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors") + proj_linear_count = 0 + kv_scale_count = 0 + for name, _ in tensors.items(): + if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"): + proj_linear_count += 1 + if name.endswith("kv_scale"): + kv_scale_count += 1 + assert proj_linear_count // 2 == kv_scale_count + # Measure checkpoint size and cleanup model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors") shutil.rmtree(quantized_model_dir) - # We expect the model to be < 160MB - target_size = 160 * (1024 * 1024) + # We expect the quantized model to be a certain size + target_size = target_size * (1024 * 1024) assert model_size < target_size
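
After quantizing with kv_cache_quant_targets=("k_proj", "v_proj"), the saved checkpoint should carry one kv_scale tensor per attention module, stored on the parent of each k_proj/v_proj pair. Below is a minimal inspection sketch of that invariant, mirroring the count check added in tests/test_auto_fp8.py; the checkpoint directory name is only an assumption taken from examples/example_static_kvcache.py and should be replaced with whatever directory was passed to save_quantized().

import safetensors.torch

# Directory name assumed from the example above; substitute your own save_dir.
checkpoint = "Meta-Llama-3-8B-Instruct-FP8-KV/model.safetensors"
tensors = safetensors.torch.load_file(checkpoint)

# The post-processing step folds each k_proj/v_proj output_scale pair into a
# single kv_scale parameter on the parent attention module, so there should be
# one kv_scale for every two k_proj/v_proj weight tensors.
kv_scales = {name: t for name, t in tensors.items() if name.endswith("kv_scale")}
proj_weights = [n for n in tensors if n.endswith(("k_proj.weight", "v_proj.weight"))]
assert len(proj_weights) // 2 == len(kv_scales)

for name, scale in sorted(kv_scales.items()):
    print(f"{name}: {scale.item():.6f}")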