From 12d2882f02c6da8611137f840934b6415a0aae91 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Fri, 7 Jun 2024 17:36:16 -0400
Subject: [PATCH] Change act_scale -> input_scale

---
 README.md            |  2 +-
 auto_fp8/modeling.py |  2 +-
 auto_fp8/quantize.py | 24 ++++++++++++------------
 examples/quantize.py | 24 ++++++++++++------------
 4 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/README.md b/README.md
index 6e79275..38144a3 100644
--- a/README.md
+++ b/README.md
@@ -86,7 +86,7 @@ Each quantized layer in the state_dict will have:
 If the config has `"activation_scheme": "static"`:
 ```
 model.layers.0.mlp.down_proj.weight < F8_E4M3
-model.layers.0.mlp.down_proj.act_scale < F32
+model.layers.0.mlp.down_proj.input_scale < F32
 model.layers.0.mlp.down_proj.weight_scale < F32
 ```
 If config has `"activation_scheme": "dynamic"`:
diff --git a/auto_fp8/modeling.py b/auto_fp8/modeling.py
index 95667d5..69ca3b8 100644
--- a/auto_fp8/modeling.py
+++ b/auto_fp8/modeling.py
@@ -115,7 +115,7 @@ def _prepare_calibration_data(calibration_tokens):
 
         # import copy
         # for layer in self.model.model.layers:
-        #     layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.act_scale)
+        #     layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.input_scale)
 
     def save_quantized(self, save_dir):
         save_quantized_model(
diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py
index c6e9099..3a1eb01 100644
--- a/auto_fp8/quantize.py
+++ b/auto_fp8/quantize.py
@@ -104,18 +104,18 @@ def __init__(
         super().__init__()
         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.act_scale = None
+        self.input_scale = None
         self.bias = bias
 
     def forward(self, x):
-        qinput, x_act_scale = per_tensor_quantize(x)
-        if self.act_scale is None:
-            self.act_scale = torch.nn.Parameter(x_act_scale)
-        elif x_act_scale > self.act_scale:
-            self.act_scale = torch.nn.Parameter(x_act_scale)
+        qinput, x_input_scale = per_tensor_quantize(x)
+        if self.input_scale is None:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
+        elif x_input_scale > self.input_scale:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
         output = fp8_gemm(
             A=qinput,
-            A_scale=self.act_scale,
+            A_scale=self.input_scale,
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
@@ -130,12 +130,12 @@ def __init__(
         qweight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.Tensor,
-        act_scale: float = 1.0,
+        input_scale: float = 1.0,
     ):
         super().__init__()
         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)
+        self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
         self.bias = bias
 
     def per_tensor_quantize(
@@ -146,10 +146,10 @@ def per_tensor_quantize(
         return qweight.to(torch.float8_e4m3fn)
 
     def forward(self, x):
-        qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale)
+        qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
         output = fp8_gemm(
             A=qinput,
-            A_scale=self.act_scale,
+            A_scale=self.input_scale,
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
@@ -247,7 +247,7 @@ def quantize_activations(
             quantizer.weight,
             quantizer.weight_scale,
             quantizer.bias,
-            quantizer.act_scale,
+            quantizer.input_scale,
         )
         replace_module(model, name, static_proj)
         del quantizer
diff --git a/examples/quantize.py b/examples/quantize.py
index d3fb840..9e6e7f9 100644
--- a/examples/quantize.py
+++ b/examples/quantize.py
@@ -85,23 +85,23 @@ def __init__(self, qweight, weight_scale, bias):
         super().__init__()
         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.act_scale = None
+        self.input_scale = None
         self.bias = bias
 
     def forward(self, x):
         # Dynamically quantize
-        qinput, x_act_scale = per_tensor_quantize(x)
+        qinput, x_input_scale = per_tensor_quantize(x)
 
         # Update scale if needed.
-        if self.act_scale is None:
-            self.act_scale = torch.nn.Parameter(x_act_scale)
-        elif x_act_scale > self.act_scale:
-            self.act_scale = torch.nn.Parameter(x_act_scale)
+        if self.input_scale is None:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
+        elif x_input_scale > self.input_scale:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
 
         # Pass quantized to next layer so it has realistic data.
         output = fp8_gemm(
             A=qinput,
-            A_scale=self.act_scale,
+            A_scale=self.input_scale,
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
@@ -111,11 +111,11 @@ def forward(self, x):
 
 
 class FP8StaticLinear(torch.nn.Module):
-    def __init__(self, qweight, weight_scale, bias, act_scale=0.0):
+    def __init__(self, qweight, weight_scale, bias, input_scale=0.0):
         super().__init__()
         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)
+        self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
         self.bias = bias
 
     def per_tensor_quantize(
@@ -129,10 +129,10 @@ def per_tensor_quantize(
         return qweight.to(torch.float8_e4m3fn)
 
     def forward(self, x):
-        qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale)
+        qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
         output = fp8_gemm(
             A=qinput,
-            A_scale=self.act_scale,
+            A_scale=self.input_scale,
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
@@ -216,7 +216,7 @@ def quantize_activations(model, calibration_tokens):
             quantizer.weight,
             quantizer.weight_scale,
             quantizer.bias,
-            quantizer.act_scale,
+            quantizer.input_scale,
         )
         replace_module(model, name, static_proj)
         del quantizer
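
For reference, a minimal sketch (not part of the patch) of how downstream loading code might group the renamed checkpoint entries, assuming the state_dict key layout shown in the README hunk; collect_fp8_scales and its grouping logic are illustrative assumptions, not an AutoFP8 API:

# Illustrative sketch only. Groups FP8 weights and scales per layer under the
# post-rename key names; `collect_fp8_scales` is a hypothetical helper.
import torch

def collect_fp8_scales(state_dict):
    """Group FP8 weight/weight_scale/input_scale tensors by layer prefix."""
    layers = {}
    for key, tensor in state_dict.items():
        if key.endswith(".weight") and tensor.dtype == torch.float8_e4m3fn:
            layers.setdefault(key[: -len(".weight")], {})["weight"] = tensor
        elif key.endswith(".weight_scale"):
            layers.setdefault(key[: -len(".weight_scale")], {})["weight_scale"] = tensor
        elif key.endswith(".input_scale"):  # named ".act_scale" before this change
            layers.setdefault(key[: -len(".input_scale")], {})["input_scale"] = tensor
    return layers

# With "activation_scheme": "static", each quantized layer yields "weight",
# "weight_scale", and "input_scale"; with "dynamic", "input_scale" is absent
# and the activation scale is computed at runtime instead.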