Commit
Support calibrating kv cache scales (#17)
* Support calibrating kv cache scales

* Add comment

* Fix weight name

* Add Qwen test

* Fix kv cache test count

* Add fixed target sizes

* Fix proj linear count

* Switch from output_scale to kv_scale

* Add example
mgoin authored Jun 18, 2024
1 parent b1c6ad6 commit 0d40b99
Showing 5 changed files with 249 additions and 69 deletions.
9 changes: 7 additions & 2 deletions auto_fp8/config.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional, Tuple


class BaseQuantizeConfig:
@@ -17,13 +17,17 @@ class BaseQuantizeConfig:
regex style matching i.e. re.search(), for each Linear layer.
By default, "re:.*lm_head" is included to ignore the embedding
Linear layer usually at the end of decoder LLMs
kv_cache_quant_targets: Tuple of Linear module names to target for
calibration of the output scales for KV cache quantization.
Usually, these should be `("k_proj", "v_proj")`.
"""

def __init__(
self,
quant_method: str = "fp8",
activation_scheme: str = "static",
ignore_patterns: List[str] = [],
ignore_patterns: List[str] = ["re:.*lm_head"],
kv_cache_quant_targets: Optional[Tuple[str]] = None,
):
if quant_method != "fp8":
raise ValueError("Only FP8 quantization is supported.")
@@ -34,4 +38,5 @@ def __init__(
self.quant_method = quant_method
self.activation_scheme = activation_scheme
self.ignore_patterns = ignore_patterns
self.kv_cache_quant_targets = kv_cache_quant_targets
self.ignored_layers = []
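
An illustrative aside, not part of the diff: a minimal construction sketch for the updated config, using the docstring's suggested ("k_proj", "v_proj") targets. The import path is an assumption.

from auto_fp8 import BaseQuantizeConfig  # import path assumed

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",  # scale calibration only runs for the static scheme
    kv_cache_quant_targets=("k_proj", "v_proj"),
)
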
49 changes: 36 additions & 13 deletions auto_fp8/modeling.py
@@ -1,5 +1,5 @@
import re
from typing import List
from typing import List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM
@@ -27,6 +27,16 @@ def __init__(
self.model, quantize_config.ignore_patterns
)

if quantize_config.kv_cache_quant_targets:
kv_cache_quant_layers = get_kv_cache_quant_layers(
self.model, quantize_config.kv_cache_quant_targets
)
if len(kv_cache_quant_layers) == 0:
raise ValueError(
f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument."
)
quantize_config.kv_cache_quant_layers = kv_cache_quant_layers

self.quantize_config = quantize_config

@classmethod
@@ -97,26 +107,28 @@ def skip(*args, **kwargs):

return cls(model, quantize_config)

def quantize(self, calibration_tokens):
def _prepare_calibration_data(calibration_tokens):
if hasattr(calibration_tokens, "input_ids"):
return calibration_tokens.input_ids
return calibration_tokens
def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):

# Always quantize the weights as they do not require calibration data
quantize_weights(self.model, self.quantize_config)

if self.quantize_config.activation_scheme == "static":
assert (
calibration_tokens is not None
), "Calibration tokens required for activation quantization"


def _prepare_calibration_data(calibration_tokens):
if hasattr(calibration_tokens, "input_ids"):
return calibration_tokens.input_ids
return calibration_tokens

quantize_activations(
self.model,
self.quantize_config,
_prepare_calibration_data(calibration_tokens),
)

# import copy
# for layer in self.model.model.layers:
# layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.input_scale)

def save_quantized(self, save_dir):
save_quantized_model(
self.model,
@@ -128,9 +140,6 @@ def save_quantized(self, save_dir):
def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
ignored_layers = set()

# TODO: don't always ignore lm_head
ignore_patterns.append("re:.*lm_head")

for name, linear in model.named_modules():
if not isinstance(linear, torch.nn.Linear):
continue
@@ -148,3 +157,17 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
ignored_layers.add(name)

return list(ignored_layers)


def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
kv_cache_quant_layers = []

for name, linear in model.named_modules():
if not isinstance(linear, torch.nn.Linear):
continue

for output_quant_target in kv_cache_quant_targets:
if name.endswith(output_quant_target):
kv_cache_quant_layers.append(name)

return kv_cache_quant_layers
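
An illustrative aside, not part of the diff: a hedged end-to-end sketch of the new flow. The AutoFP8ForCausalLM wrapper name, the from_pretrained signature, and the model id are assumptions, not taken from this commit.

from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig  # class name assumed

model_id = "Qwen/Qwen1.5-0.5B-Chat"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    kv_cache_quant_targets=("k_proj", "v_proj"),
)

# quantize() accepts either a tensor of input_ids or a BatchEncoding, since
# _prepare_calibration_data() unwraps .input_ids when present.
calibration_tokens = tokenizer(
    ["calibration sample text for auto_fp8"], return_tensors="pt"
).input_ids

model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)  # signature assumed
model.quantize(calibration_tokens)
model.save_quantized(model_id.split("/")[-1] + "-FP8-KV")
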
156 changes: 109 additions & 47 deletions auto_fp8/quantize.py
@@ -1,6 +1,6 @@
import gc
import re
from typing import List, Tuple
from typing import Optional, Tuple
import copy

import torch
@@ -61,14 +61,22 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
return qweight, scale


def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)


def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
if A.numel() == 0:
# Deal with empty tensors (triggered by empty MoE experts)
return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)

native_fp8_support = (
torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
)

# TODO: Disable native fp8 gemm for now, always just dequantize
# native_fp8_support = (
# torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
# )
native_fp8_support = False
if native_fp8_support:
need_reshape = A.dim() == 3
if need_reshape:
@@ -98,25 +106,24 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
return output


class FP8StaticLinearQuantizer(torch.nn.Module):
# Class responsible for quantizing weights
class FP8DynamicLinear(torch.nn.Module):
def __init__(
self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor
self,
weight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.nn.Parameter,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight = torch.nn.Parameter(weight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
self.input_scale = None
self.bias = bias

def forward(self, x):
qinput, x_input_scale = per_tensor_quantize(x)
if self.input_scale is None:
self.input_scale = torch.nn.Parameter(x_input_scale)
elif x_input_scale > self.input_scale:
self.input_scale = torch.nn.Parameter(x_input_scale)
qinput, x_scale = per_tensor_quantize(x)
output = fp8_gemm(
A=qinput,
A_scale=self.input_scale,
A_scale=x_scale,
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
@@ -125,29 +132,29 @@ def forward(self, x):
return output


class FP8StaticLinear(torch.nn.Module):
# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer
class FP8StaticLinearQuantizer(torch.nn.Module):
def __init__(
self,
qweight: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.Tensor,
input_scale: float = 1.0,
bias: torch.nn.Parameter,
quantize_output: bool = False,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight = torch.nn.Parameter(weight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
self.bias = bias

def per_tensor_quantize(
self, tensor: torch.Tensor, inv_scale: float
) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)
self.input_scale = None
self.output_scale = None
self.quantize_output = quantize_output

def forward(self, x):
qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
qinput, x_input_scale = per_tensor_quantize(x)
if self.input_scale is None:
self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
elif x_input_scale > self.input_scale:
self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
output = fp8_gemm(
A=qinput,
A_scale=self.input_scale,
@@ -156,26 +163,51 @@ def forward(self, x):
bias=self.bias,
out_dtype=x.dtype,
)

# Optionally, quantize output and record scale
if self.quantize_output:
qoutput, output_scale = per_tensor_quantize(output)
if self.output_scale is None:
self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
elif output_scale > self.output_scale:
self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
output = qoutput.to(output.dtype) * output_scale

return output


class FP8DynamicLinear(torch.nn.Module):
def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor):
# Module responsible for representing the final checkpoint representation
class FP8StaticLinear(torch.nn.Module):
def __init__(
self,
weight: torch.nn.Parameter,
weight_scale: torch.nn.Parameter,
bias: torch.nn.Parameter,
input_scale: torch.nn.Parameter,
output_scale: Optional[torch.nn.Parameter] = None,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
self.weight = weight
self.weight_scale = weight_scale
self.bias = bias
self.input_scale = input_scale
self.output_scale = output_scale

def forward(self, x):
qinput, x_scale = per_tensor_quantize(x)
qinput = static_per_tensor_quantize(x, self.input_scale)
output = fp8_gemm(
A=qinput,
A_scale=x_scale,
A_scale=self.input_scale,
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
out_dtype=x.dtype,
)

if self.output_scale:
qoutput = static_per_tensor_quantize(output, self.output_scale)
output = qoutput.to(output.dtype) * self.output_scale

return output
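
# Illustrative aside, not part of the diff: a numeric sketch of the static
# per-tensor FP8 round trip used above. The max-abs scale below is an
# assumption standing in for a calibrated input_scale; per_tensor_quantize's
# body is not shown in this hunk.
import torch

finfo = torch.finfo(torch.float8_e4m3fn)
x = torch.randn(4, 8, dtype=torch.float16)

scale = x.abs().max().to(torch.float32) / finfo.max  # assumed max-abs rule

q = (x / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
x_dq = q.to(torch.float16) * scale  # dequantize to inspect the error

print((x - x_dq).abs().max())  # small per-tensor quantization error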


@@ -194,7 +226,6 @@ def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.
def quantize_weights(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
ignored_layers: List[str] = [],
):
named_modules = list(model.named_modules())
for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
@@ -203,9 +234,11 @@ def quantize_weights(
or name in quantize_config.ignored_layers
):
continue
quant_weight, quant_scale = per_tensor_quantize(linear.weight)
quant_weight, weight_scale = per_tensor_quantize(linear.weight)
bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias)
quant_linear = FP8DynamicLinear(
weight=quant_weight, weight_scale=weight_scale, bias=bias
)
replace_module(model, name, quant_linear)
del linear.weight
del linear.bias
@@ -217,7 +250,6 @@ def quantize_activations(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
calibration_tokens,
ignored_layers: List[str] = [],
):
# Replace weight quantizer with a dynamic activation quantizer observer
for name, dynamic_quant_linear in model.named_modules():
@@ -227,9 +259,13 @@ def quantize_activations(
):
continue
quantizer = FP8StaticLinearQuantizer(
dynamic_quant_linear.weight,
dynamic_quant_linear.weight_scale,
dynamic_quant_linear.bias,
weight=dynamic_quant_linear.weight,
weight_scale=dynamic_quant_linear.weight_scale,
bias=dynamic_quant_linear.bias,
quantize_output=(
hasattr(quantize_config, "kv_cache_quant_layers")
and name in quantize_config.kv_cache_quant_layers
),
)
replace_module(model, name, quantizer)
del dynamic_quant_linear
@@ -251,21 +287,45 @@ def quantize_activations(
):
continue
static_proj = FP8StaticLinear(
quantizer.weight,
quantizer.weight_scale,
quantizer.bias,
quantizer.input_scale,
weight=quantizer.weight,
weight_scale=quantizer.weight_scale,
bias=quantizer.bias,
input_scale=quantizer.input_scale,
output_scale=quantizer.output_scale,
)
replace_module(model, name, static_proj)
del quantizer
cleanup_memory()

# Post-process step for kv cache scales to take the k/v module
# `output_scale` parameters, take the max of them, and store them in
# the parent attention module as `kv_scale`
# NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block
if hasattr(quantize_config, "kv_cache_quant_layers"):
# Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...]
# so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
kv_proj_pairs = zip(*[iter(quantize_config.kv_cache_quant_layers)]*2)
for k_proj_name, v_proj_name in kv_proj_pairs:
parent_module_name = ".".join(k_proj_name.split(".")[:-1])
assert parent_module_name == ".".join(v_proj_name.split(".")[:-1])
parent_module = dict(model.named_modules())[parent_module_name]

k_proj = dict(model.named_modules())[k_proj_name]
v_proj = dict(model.named_modules())[v_proj_name]

kv_scale = max(k_proj.output_scale, v_proj.output_scale)
parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False)

# Remove output_scale from k_proj and v_proj
k_proj.output_scale = None
v_proj.output_scale = None
cleanup_memory()


def save_quantized_model(
model: AutoModelForCausalLM,
quant_config: BaseQuantizeConfig,
save_dir: str,
ignored_layers: List[str] = [],
):
print(model)
print(f"Saving the model to {save_dir}")
@@ -276,6 +336,8 @@ def save_quantized_model(
"ignored_layers": quant_config.ignored_layers,
}
}
if hasattr(quant_config, "kv_cache_quant_layers"):
static_q_dict["quantization_config"]["kv_cache_scheme"] = "static"
model.config.update(static_q_dict)
model.save_pretrained(save_dir)
tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
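
For reference (an editorial aside, not part of the diff): a standalone sketch of the pairwise grouping used in the kv_scale post-processing step above. The layer names are hypothetical.

# zip(*[iter(lst)] * 2) walks a flat, ordered list two items at a time,
# pairing each layer's k_proj with its v_proj as the code's comment assumes.
kv_cache_quant_layers = [
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.1.self_attn.k_proj",
    "model.layers.1.self_attn.v_proj",
]

for k_proj_name, v_proj_name in zip(*[iter(kv_cache_quant_layers)] * 2):
    parent = ".".join(k_proj_name.split(".")[:-1])
    assert parent == ".".join(v_proj_name.split(".")[:-1])
    print(parent, "->", (k_proj_name, v_proj_name))
# model.layers.0.self_attn -> ('model.layers.0.self_attn.k_proj', 'model.layers.0.self_attn.v_proj')
# model.layers.1.self_attn -> ('model.layers.1.self_attn.k_proj', 'model.layers.1.self_attn.v_proj')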