
Commit

Change act_scale -> input_scale
mgoin authored Jun 7, 2024
1 parent 009dc55 commit 12d2882
Showing 4 changed files with 26 additions and 26 deletions.
README.md: 2 changes (1 addition, 1 deletion)
@@ -86,7 +86,7 @@ Each quantized layer in the state_dict will have:
If the config has `"activation_scheme": "static"`:
```
model.layers.0.mlp.down_proj.weight < F8_E4M3
-model.layers.0.mlp.down_proj.act_scale < F32
+model.layers.0.mlp.down_proj.input_scale < F32
model.layers.0.mlp.down_proj.weight_scale < F32
```
If config has `"activation_scheme": "dynamic"`:
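For readers checking an existing quantized checkpoint against this rename, a minimal sketch of how one could list the affected tensors; the checkpoint directory and the safetensors file name are assumed examples, not part of this commit.

```python
# Hypothetical check: list the scale tensors in a saved FP8 checkpoint and
# confirm they now use the `input_scale` suffix. The path
# "fp8-model/model.safetensors" is an assumed example path.
from safetensors import safe_open

with safe_open("fp8-model/model.safetensors", framework="pt") as f:
    for name in f.keys():
        if name.endswith(("input_scale", "weight_scale")):
            print(name, f.get_tensor(name).dtype)  # expected: torch.float32
```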
auto_fp8/modeling.py: 2 changes (1 addition, 1 deletion)
@@ -115,7 +115,7 @@ def _prepare_calibration_data(calibration_tokens):

# import copy
# for layer in self.model.model.layers:
-# layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.act_scale)
+# layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.input_scale)

def save_quantized(self, save_dir):
save_quantized_model(
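Taken together with the save_quantized method above, the intended end-to-end flow is roughly the sketch below; AutoFP8ForCausalLM, BaseQuantizeConfig, and the constructor arguments follow the project README as I understand it, and anything beyond what this diff itself shows should be treated as an assumption.

```python
# Rough end-to-end usage sketch (exact argument names are assumptions):
# load a model, calibrate static activation scales, and save an FP8
# checkpoint whose layers carry weight_scale / input_scale tensors.
from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # example model, not from this commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
calibration_tokens = tokenizer(
    ["auto_fp8 calibration sample"], return_tensors="pt"
).input_ids

quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)
model.quantize(calibration_tokens)
model.save_quantized("Meta-Llama-3-8B-Instruct-FP8")
```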
auto_fp8/quantize.py: 24 changes (12 additions, 12 deletions)
@@ -104,18 +104,18 @@ def __init__(
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-self.act_scale = None
+self.input_scale = None
self.bias = bias

def forward(self, x):
-qinput, x_act_scale = per_tensor_quantize(x)
-if self.act_scale is None:
-self.act_scale = torch.nn.Parameter(x_act_scale)
-elif x_act_scale > self.act_scale:
-self.act_scale = torch.nn.Parameter(x_act_scale)
+qinput, x_input_scale = per_tensor_quantize(x)
+if self.input_scale is None:
+self.input_scale = torch.nn.Parameter(x_input_scale)
+elif x_input_scale > self.input_scale:
+self.input_scale = torch.nn.Parameter(x_input_scale)
output = fp8_gemm(
A=qinput,
-A_scale=self.act_scale,
+A_scale=self.input_scale,
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
@@ -130,12 +130,12 @@ def __init__(
qweight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.Tensor,
-act_scale: float = 1.0,
+input_scale: float = 1.0,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)
+self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
self.bias = bias

def per_tensor_quantize(
@@ -146,10 +146,10 @@ def per_tensor_quantize(
return qweight.to(torch.float8_e4m3fn)

def forward(self, x):
-qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale)
+qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
output = fp8_gemm(
A=qinput,
-A_scale=self.act_scale,
+A_scale=self.input_scale,
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
@@ -247,7 +247,7 @@ def quantize_activations(
quantizer.weight,
quantizer.weight_scale,
quantizer.bias,
-quantizer.act_scale,
+quantizer.input_scale,
)
replace_module(model, name, static_proj)
del quantizer
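Both forward passes above rely on per_tensor_quantize, which lies outside this hunk; the following is a minimal sketch of such a helper, consistent with how it is called here (returning an FP8 tensor plus a float scale), not the repository's exact implementation.

```python
import torch

def per_tensor_quantize(tensor: torch.Tensor):
    # Sketch only: choose one scale for the whole tensor so that the
    # largest |value| maps to the FP8 E4M3 maximum, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = tensor.abs().max().clamp(min=1e-12)
    scale = amax / finfo.max  # dequantization scale returned to the caller
    qtensor = (tensor / scale).clamp(min=finfo.min, max=finfo.max)
    return qtensor.to(torch.float8_e4m3fn), scale.float()
```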
examples/quantize.py: 24 changes (12 additions, 12 deletions)
@@ -85,23 +85,23 @@ def __init__(self, qweight, weight_scale, bias):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-self.act_scale = None
+self.input_scale = None
self.bias = bias

def forward(self, x):
# Dynamically quantize
-qinput, x_act_scale = per_tensor_quantize(x)
+qinput, x_input_scale = per_tensor_quantize(x)

# Update scale if needed.
-if self.act_scale is None:
-self.act_scale = torch.nn.Parameter(x_act_scale)
-elif x_act_scale > self.act_scale:
-self.act_scale = torch.nn.Parameter(x_act_scale)
+if self.input_scale is None:
+self.input_scale = torch.nn.Parameter(x_input_scale)
+elif x_input_scale > self.input_scale:
+self.input_scale = torch.nn.Parameter(x_input_scale)

# Pass quantized to next layer so it has realistic data.
output = fp8_gemm(
A=qinput,
-A_scale=self.act_scale,
+A_scale=self.input_scale,
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
@@ -111,11 +111,11 @@ def forward(self, x):


class FP8StaticLinear(torch.nn.Module):
-def __init__(self, qweight, weight_scale, bias, act_scale=0.0):
+def __init__(self, qweight, weight_scale, bias, input_scale=0.0):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)
+self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
self.bias = bias

def per_tensor_quantize(
@@ -129,10 +129,10 @@ def per_tensor_quantize(
return qweight.to(torch.float8_e4m3fn)

def forward(self, x):
-qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale)
+qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
output = fp8_gemm(
A=qinput,
-A_scale=self.act_scale,
+A_scale=self.input_scale,
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
@@ -216,7 +216,7 @@ def quantize_activations(model, calibration_tokens):
quantizer.weight,
quantizer.weight_scale,
quantizer.bias,
-quantizer.act_scale,
+quantizer.input_scale,
)
replace_module(model, name, static_proj)
del quantizer
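The quantize_activations hunks above show where the renamed attribute is consumed; as a condensed stand-in for the example script (module iteration and replace_module are simplified from the surrounding code, not reproduced exactly), the calibrated swap looks roughly like this.

```python
# Condensed sketch of the dynamic -> static swap: after calibration, the
# running-max input_scale gathered by FP8DynamicLinear seeds an
# FP8StaticLinear that replaces it in the model.
for name, quantizer in model.named_modules():
    if isinstance(quantizer, FP8DynamicLinear):
        static_proj = FP8StaticLinear(
            quantizer.weight,
            quantizer.weight_scale,
            quantizer.bias,
            quantizer.input_scale,  # renamed from act_scale in this commit
        )
        replace_module(model, name, static_proj)
        del quantizer
```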
