Update quantize.py
comaniac authored May 28, 2024
1 parent c4a9594 commit c762239
Showing 1 changed file with 8 additions and 2 deletions.
auto_fp8/quantize.py: 10 changes (8 additions & 2 deletions)
@@ -61,8 +61,12 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:


 def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
+    if A.numel() == 0:
+        # Deal with empty tensors (triggered by empty MoE experts)
+        return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
+
     native_fp8_support = (
-        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
     )
     if native_fp8_support:
         need_reshape = A.dim() == 3
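
The capability threshold change in this hunk lowers the cutoff from (9, 0), which is Hopper (e.g. H100), to (8, 9), which is Ada Lovelace (e.g. L4, RTX 4090). Ada GPUs also have FP8 tensor cores, so with the new check they take the native FP8 path instead of the torch.nn.functional.linear fallback. A quick way to see which side of the check a given machine falls on (plain PyTorch, nothing assumed beyond a CUDA build):

import torch

# Prints the (major, minor) compute capability of the current GPU.
# (8, 9) = Ada Lovelace (L4, RTX 4090); (9, 0) = Hopper (H100).
if torch.cuda.is_available():
    print(torch.cuda.get_device_capability())
else:
    print("No CUDA device; fp8_gemm would use the non-native fallback path.")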
@@ -81,7 +85,9 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
             bias=bias,
         )
         if need_reshape:
-            output = output.reshape((batch_size, output.shape[0] // batch_size, output.shape[1]))
+            output = output.reshape(
+                batch_size, output.shape[0] // batch_size, output.shape[1]
+            )
     else:
         output = torch.nn.functional.linear(
             A.to(out_dtype) * A_scale,
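
The empty-tensor early return matters when a MoE router sends no tokens to an expert: that expert's activation tensor has zero rows, and without the guard the GEMM path would be invoked on an empty input. A minimal sketch of the new behaviour, assuming fp8_gemm is importable from auto_fp8.quantize as defined in this file, and using placeholder weights and unit scales (not real calibration values):

import torch
from auto_fp8.quantize import fp8_gemm  # assumes the auto_fp8 package is installed

in_features, out_features = 64, 128
weight = torch.randn(out_features, in_features)  # [out_features, in_features], linear-style weight
A = torch.empty(0, in_features)                  # expert received zero tokens

out = fp8_gemm(
    A,
    A_scale=torch.tensor(1.0),  # placeholder scale
    B=weight,
    B_scale=torch.tensor(1.0),  # placeholder scale
    bias=None,
    out_dtype=torch.float16,
)
print(out.shape)  # torch.Size([0, 128]): empty output, no matmul is launched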
