
Commit

Fix weight name
mgoin committed Jun 14, 2024
1 parent 211c4fc commit 96cc0b0
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions auto_fp8/quantize.py
@@ -109,12 +109,12 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
 class FP8DynamicLinear(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.Tensor,
+        weight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.nn.Parameter,
     ):
         super().__init__()
-        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
         self.bias = bias

@@ -123,7 +123,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=x_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
@@ -135,13 +135,13 @@ def forward(self, x):
 class FP8StaticLinearQuantizer(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.Tensor,
+        weight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.nn.Parameter,
         quantize_output: bool = False,
     ):
         super().__init__()
-        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
         self.bias = bias
         self.input_scale = None
@@ -157,7 +157,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
@@ -179,14 +179,14 @@ def forward(self, x):
 class FP8StaticLinear(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.nn.Parameter,
+        weight: torch.nn.Parameter,
         weight_scale: torch.nn.Parameter,
         bias: torch.nn.Parameter,
         input_scale: torch.nn.Parameter,
         output_scale: Optional[torch.nn.Parameter] = None,
     ):
         super().__init__()
-        self.qweight = qweight
+        self.weight = weight
         self.weight_scale = weight_scale
         self.bias = bias
         self.input_scale = input_scale
@@ -197,7 +197,7 @@ def forward(self, x):
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
-            B=self.qweight,
+            B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
@@ -236,7 +236,7 @@ def quantize_weights(
         quant_weight, weight_scale = per_tensor_quantize(linear.weight.clone())
         bias = linear.bias.clone() if linear.bias is not None else None
         quant_linear = FP8DynamicLinear(
-            qweight=quant_weight, weight_scale=weight_scale, bias=bias
+            weight=quant_weight, weight_scale=weight_scale, bias=bias
         )
         replace_module(model, name, quant_linear)
         del linear.weight
@@ -258,7 +258,7 @@ def quantize_activations(
         ):
             continue
         quantizer = FP8StaticLinearQuantizer(
-            qweight=dynamic_quant_linear.qweight,
+            weight=dynamic_quant_linear.weight,
             weight_scale=dynamic_quant_linear.weight_scale,
             bias=dynamic_quant_linear.bias,
             quantize_output=(
@@ -287,7 +287,7 @@ def quantize_activations(
         ):
             continue
         static_proj = FP8StaticLinear(
-            qweight=quantizer.qweight,
+            weight=quantizer.weight,
             weight_scale=quantizer.weight_scale,
             bias=quantizer.bias,
             input_scale=quantizer.input_scale,
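
For illustration only (not part of the commit), the sketch below shows the practical effect of the rename: after this change the FP8 tensor is registered on the module as weight rather than qweight, so state_dict() keys follow the conventional weight naming. FP8DynamicLinearSketch and per_tensor_quantize_stub are hypothetical stand-ins for the repository's FP8DynamicLinear and per_tensor_quantize, not the actual implementations.

import torch

def per_tensor_quantize_stub(tensor: torch.Tensor):
    # Hypothetical stand-in for per_tensor_quantize: scale the tensor into the
    # float8_e4m3fn range and return the FP8 weight plus its per-tensor scale.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = tensor.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (tensor / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale

class FP8DynamicLinearSketch(torch.nn.Module):
    # Mirrors the renamed constructor from the diff: the FP8 tensor is
    # registered under "weight" (previously "qweight").
    def __init__(self, weight: torch.Tensor, weight_scale: torch.Tensor, bias):
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        self.bias = bias

linear = torch.nn.Linear(16, 32, bias=False)
quant_weight, weight_scale = per_tensor_quantize_stub(linear.weight.detach().clone())
module = FP8DynamicLinearSketch(weight=quant_weight, weight_scale=weight_scale, bias=None)
print(list(module.state_dict().keys()))  # -> ['weight', 'weight_scale']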
