Support calibrating kv cache scales
mgoin committed Jun 13, 2024
1 parent 2e134d8 commit b015904
Showing 4 changed files with 186 additions and 59 deletions.
9 changes: 7 additions & 2 deletions auto_fp8/config.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional, Tuple


class BaseQuantizeConfig:
@@ -17,13 +17,17 @@ class BaseQuantizeConfig:
regex style matching i.e. re.search(), for each Linear layer.
By default, "re:.*lm_head" is included to ignore the embedding
Linear layer usually at the end of decoder LLMs
kv_cache_quant_targets: Tuple of Linear module names to target for
calibration of the output scales for KV cache quantization.
Usually, these should be `("k_proj", "v_proj")`.
"""

def __init__(
self,
quant_method: str = "fp8",
activation_scheme: str = "static",
ignore_patterns: List[str] = [],
ignore_patterns: List[str] = ["re:.*lm_head"],
kv_cache_quant_targets: Optional[Tuple[str]] = None,
):
if quant_method != "fp8":
raise ValueError("Only FP8 quantization is supported.")
@@ -34,4 +38,5 @@ def __init__(
self.quant_method = quant_method
self.activation_scheme = activation_scheme
self.ignore_patterns = ignore_patterns
self.kv_cache_quant_targets = kv_cache_quant_targets
self.ignored_layers = []
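
For context, a minimal usage sketch of the updated config (illustrative only: the top-level import from auto_fp8 is assumed to re-export auto_fp8.config, and the ("k_proj", "v_proj") value follows the docstring above):

from auto_fp8 import BaseQuantizeConfig  # assumed package re-export

# ignore_patterns now defaults to ["re:.*lm_head"], so the lm_head Linear layer
# is skipped without passing anything extra.
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    kv_cache_quant_targets=("k_proj", "v_proj"),  # calibrate output scales on the K/V projections
)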
34 changes: 29 additions & 5 deletions auto_fp8/modeling.py
@@ -1,5 +1,5 @@
import re
from typing import List
from typing import List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM
@@ -27,6 +27,16 @@ def __init__(
self.model, quantize_config.ignore_patterns
)

if quantize_config.kv_cache_quant_targets:
kv_cache_quant_layers = get_kv_cache_quant_layer(
self.model, quantize_config.kv_cache_quant_targets
)
if len(kv_cache_quant_layers) == 0:
raise ValueError(
f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument."
)
quantize_config.kv_cache_quant_layers = kv_cache_quant_layers

self.quantize_config = quantize_config

@classmethod
@@ -97,7 +107,7 @@ def skip(*args, **kwargs):

return cls(model, quantize_config)

def quantize(self, calibration_tokens):
def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
def _prepare_calibration_data(calibration_tokens):
if hasattr(calibration_tokens, "input_ids"):
return calibration_tokens.input_ids
@@ -107,6 +117,9 @@ def _prepare_calibration_data(calibration_tokens):
quantize_weights(self.model, self.quantize_config)

if self.quantize_config.activation_scheme == "static":
assert (
calibration_tokens is not None
), "Calibration tokens required for activation quantization"
quantize_activations(
self.model,
self.quantize_config,
@@ -128,9 +141,6 @@ def save_quantized(self, save_dir):
def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
ignored_layers = set()

# TODO: don't always ignore lm_head
ignore_patterns.append("re:.*lm_head")

for name, linear in model.named_modules():
if not isinstance(linear, torch.nn.Linear):
continue
@@ -148,3 +158,17 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
ignored_layers.add(name)

return list(ignored_layers)


def get_kv_cache_quant_layer(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
kv_cache_quant_layers = set()

for name, linear in model.named_modules():
if not isinstance(linear, torch.nn.Linear):
continue

for output_quant_target in kv_cache_quant_targets:
if name.endswith(output_quant_target):
kv_cache_quant_layers.add(name)

return list(kv_cache_quant_layers)
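
For orientation, a hedged end-to-end sketch of how these pieces are typically driven. The quantize() and save_quantized() methods are the ones shown in this file; the package-level imports, the exact from_pretrained() signature, and the model name are assumptions, not confirmed by this diff:

from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig  # assumed exports

model_id = "facebook/opt-125m"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
# A real calibration set should contain many samples; one keeps the sketch short.
tokens = tokenizer("Some calibration text.", return_tensors="pt").input_ids

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",                   # static scheme needs calibration tokens
    kv_cache_quant_targets=("k_proj", "v_proj"),  # also record K/V projection output scales
)

model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)  # assumed signature
model.quantize(tokens)       # quantize weights, then calibrate activation/KV-cache scales
model.save_quantized("opt-125m-fp8-kv")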
136 changes: 88 additions & 48 deletions auto_fp8/quantize.py
@@ -1,6 +1,6 @@
import gc
import re
from typing import List, Tuple
from typing import Optional, Tuple

import torch
import tqdm
@@ -60,13 +60,21 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
return qweight, scale


def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)


def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
if A.numel() == 0:
# Deal with empty tensors (triggered by empty MoE experts)
return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)

native_fp8_support = (
torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
torch.cuda.is_available()
and torch.cuda.get_device_capability() >= (8, 9)
and False
)
if native_fp8_support:
need_reshape = A.dim() == 3
@@ -97,84 +105,108 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
return output


class FP8StaticLinearQuantizer(torch.nn.Module):
# Class responsible for quantizing weights
class FP8DynamicLinear(torch.nn.Module):
def __init__(
self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor
self,
qweight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.nn.Parameter,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
self.input_scale = None
self.bias = bias

def forward(self, x):
qinput, x_input_scale = per_tensor_quantize(x)
if self.input_scale is None:
self.input_scale = torch.nn.Parameter(x_input_scale)
elif x_input_scale > self.input_scale:
self.input_scale = torch.nn.Parameter(x_input_scale)
qinput, x_scale = per_tensor_quantize(x)
output = fp8_gemm(
A=qinput,
A_scale=self.input_scale,
B=self.weight,
A_scale=x_scale,
B=self.qweight,
B_scale=self.weight_scale,
bias=self.bias,
out_dtype=x.dtype,
)
return output


class FP8StaticLinear(torch.nn.Module):
# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer
class FP8StaticLinearQuantizer(torch.nn.Module):
def __init__(
self,
qweight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.Tensor,
input_scale: float = 1.0,
bias: torch.nn.Parameter,
quantize_output: bool = False,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
self.bias = bias

def per_tensor_quantize(
self, tensor: torch.Tensor, inv_scale: float
) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)
self.input_scale = None
self.output_scale = None
self.quantize_output = quantize_output

def forward(self, x):
qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
qinput, x_input_scale = per_tensor_quantize(x)
if self.input_scale is None:
self.input_scale = torch.nn.Parameter(x_input_scale)
elif x_input_scale > self.input_scale:
self.input_scale = torch.nn.Parameter(x_input_scale)
output = fp8_gemm(
A=qinput,
A_scale=self.input_scale,
B=self.weight,
B=self.qweight,
B_scale=self.weight_scale,
bias=self.bias,
out_dtype=x.dtype,
)

# Optionally, quantize output and record scale
if self.quantize_output:
qoutput, output_scale = per_tensor_quantize(output)
if self.output_scale is None:
self.output_scale = torch.nn.Parameter(output_scale)
elif output_scale > self.output_scale:
self.output_scale = torch.nn.Parameter(output_scale)
output = qoutput.to(output.dtype) * output_scale

return output


class FP8DynamicLinear(torch.nn.Module):
def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor):
# Module responsible for representing the final checkpoint representation
class FP8StaticLinear(torch.nn.Module):
def __init__(
self,
qweight: torch.nn.Parameter,
weight_scale: torch.nn.Parameter,
bias: torch.nn.Parameter,
input_scale: torch.nn.Parameter,
output_scale: Optional[torch.nn.Parameter] = None,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
self.qweight = qweight
self.weight_scale = weight_scale
self.bias = bias
self.input_scale = input_scale
self.output_scale = output_scale

def forward(self, x):
qinput, x_scale = per_tensor_quantize(x)
qinput = static_per_tensor_quantize(x, self.input_scale)
output = fp8_gemm(
A=qinput,
A_scale=x_scale,
B=self.weight,
A_scale=self.input_scale,
B=self.qweight,
B_scale=self.weight_scale,
bias=self.bias,
out_dtype=x.dtype,
)

if self.output_scale:
qoutput = static_per_tensor_quantize(output, self.output_scale)
output = qoutput.to(output.dtype) * self.output_scale

return output


@@ -193,7 +225,6 @@ def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.
def quantize_weights(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
ignored_layers: List[str] = [],
):
named_modules = list(model.named_modules())
for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
@@ -202,9 +233,11 @@ def quantize_weights(
or name in quantize_config.ignored_layers
):
continue
quant_weight, quant_scale = per_tensor_quantize(linear.weight.clone())
quant_weight, weight_scale = per_tensor_quantize(linear.weight.clone())
bias = linear.bias.clone() if linear.bias is not None else None
quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias)
quant_linear = FP8DynamicLinear(
qweight=quant_weight, weight_scale=weight_scale, bias=bias
)
replace_module(model, name, quant_linear)
del linear.weight
del linear.bias
@@ -216,7 +249,6 @@ def quantize_activations(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
calibration_tokens,
ignored_layers: List[str] = [],
):
# Replace weight quantizer with a dynamic activation quantizer observer
for name, dynamic_quant_linear in model.named_modules():
@@ -226,16 +258,22 @@
):
continue
quantizer = FP8StaticLinearQuantizer(
dynamic_quant_linear.weight,
dynamic_quant_linear.weight_scale,
dynamic_quant_linear.bias,
qweight=dynamic_quant_linear.qweight,
weight_scale=dynamic_quant_linear.weight_scale,
bias=dynamic_quant_linear.bias,
quantize_output=(
hasattr(quantize_config, "kv_cache_quant_layers")
and name in quantize_config.kv_cache_quant_layers
),
)
replace_module(model, name, quantizer)
del dynamic_quant_linear
cleanup_memory()

# Pass through calibration data to measure activation scales
with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
with tqdm.tqdm(
total=calibration_tokens.shape[0], desc="Calibrating activation scales"
) as pbar:
for row_idx in range(calibration_tokens.shape[0]):
model(calibration_tokens[row_idx].reshape(1, -1))
cleanup_memory()
@@ -249,10 +287,11 @@ def quantize_activations(
):
continue
static_proj = FP8StaticLinear(
quantizer.weight,
quantizer.weight_scale,
quantizer.bias,
quantizer.input_scale,
qweight=quantizer.qweight,
weight_scale=quantizer.weight_scale,
bias=quantizer.bias,
input_scale=quantizer.input_scale,
output_scale=quantizer.output_scale,
)
replace_module(model, name, static_proj)
del quantizer
@@ -263,7 +302,6 @@ def save_quantized_model(
model: AutoModelForCausalLM,
quant_config: BaseQuantizeConfig,
save_dir: str,
ignored_layers: List[str] = [],
):
print(model)
print(f"Saving the model to {save_dir}")
@@ -274,6 +312,8 @@
"ignored_layers": quant_config.ignored_layers,
}
}
if hasattr(quant_config, "kv_cache_quant_layers"):
static_q_dict["quantization_config"]["kv_cache_scheme"] = "static"
model.config.update(static_q_dict)
model.save_pretrained(save_dir)
tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
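
To make the scale bookkeeping in quantize.py concrete, here is a small self-contained sketch (not part of the commit) of static per-tensor FP8 quantization and dequantization in the style of static_per_tensor_quantize above; the tensor shape and the max-abs scale choice are arbitrary illustrations:

import torch

def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
    # Clamp to the representable range of float8_e4m3fn, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

x = torch.randn(4, 8, dtype=torch.float16)
# Per-tensor scale chosen so the largest magnitude maps to the FP8 max.
scale = x.abs().max().to(torch.float32) / torch.finfo(torch.float8_e4m3fn).max
q = static_per_tensor_quantize(x, scale)   # FP8 payload stored in the checkpoint
x_hat = q.to(torch.float16) * scale        # dequantize with the recorded scale
print((x - x_hat).abs().max())             # quantization error stays small

During calibration, FP8StaticLinearQuantizer keeps the running maximum of these per-batch scales for inputs (and, for the kv_cache_quant_targets layers, for outputs), and FP8StaticLinear then reuses the frozen scales at inference time.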