From c7622394042b6337cb321e945768db7658c4dfcb Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Tue, 28 May 2024 13:48:51 -0700
Subject: [PATCH] Update quantize.py

---
 auto_fp8/quantize.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py
index ef7ff4d..c6e9099 100644
--- a/auto_fp8/quantize.py
+++ b/auto_fp8/quantize.py
@@ -61,8 +61,12 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
 
 
 def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
+    if A.numel() == 0:
+        # Deal with empty tensors (triggered by empty MoE experts)
+        return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
+
     native_fp8_support = (
-        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
     )
     if native_fp8_support:
         need_reshape = A.dim() == 3
@@ -81,7 +85,9 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
             bias=bias,
         )
         if need_reshape:
-            output = output.reshape((batch_size, output.shape[0] // batch_size, output.shape[1]))
+            output = output.reshape(
+                batch_size, output.shape[0] // batch_size, output.shape[1]
+            )
     else:
         output = torch.nn.functional.linear(
             A.to(out_dtype) * A_scale,
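
A minimal sketch of what the new empty-tensor guard does for empty MoE experts, assuming the patched helper is importable as auto_fp8.quantize.fp8_gemm; the shapes and scales below are illustrative, not taken from the patch. Separately, relaxing the capability check from (9, 0) to (8, 9) lets the native torch._scaled_mm path run on Ada Lovelace GPUs (SM 8.9, e.g. L4/L40/RTX 4090) instead of requiring Hopper (SM 9.0).

import torch

from auto_fp8.quantize import fp8_gemm  # assumed import path for the patched helper

# An "empty expert" in a MoE layer receives zero routed tokens, so its
# activation batch A has numel() == 0. With the guard above, fp8_gemm
# short-circuits and returns an empty (0, B.shape[0]) tensor instead of
# handing a zero-size input to torch._scaled_mm / F.linear.
N, K = 128, 64                                        # illustrative sizes
A = torch.empty((0, K), dtype=torch.float8_e4m3fn)    # no tokens routed to this expert
B = torch.zeros((N, K)).to(torch.float8_e4m3fn)       # FP8 weight, shape (N, K)
A_scale = torch.tensor(1.0)                           # illustrative per-tensor scales
B_scale = torch.tensor(1.0)

out = fp8_gemm(A, A_scale, B, B_scale, bias=None, out_dtype=torch.float16)
print(out.shape)  # torch.Size([0, 128]), i.e. (0, B.shape[0]) from the early return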