[AMD][FP8] Using MI300 FP8 format on ROCm for block_quant (#12134)
Signed-off-by: Gregory Shtrasberg <[email protected]>
gshtras authored Jan 17, 2025
1 parent 54cacf0 commit b5b57e3
Showing 2 changed files with 44 additions and 3 deletions.
33 changes: 33 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -247,6 +247,15 @@ def create_weights(
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
if current_platform.is_rocm():
weight, weight_scale, _ = \
normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale)
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale,
requires_grad=False)
return
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
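On ROCm this path rewrites the checkpoint's OCP float8_e4m3fn weights into the MI300-native float8_e4m3fnuz format before use. normalize_e4m3fn_to_e4m3fnuz is defined in vLLM's shared w8a8 quantization utilities; the sketch below is an approximation of what such a helper does, not a verbatim copy. The bit-pattern reinterpret and the scale doubling follow from the format definitions: e4m3fnuz uses a larger exponent bias, so the same bits decode to half the e4m3fn value, and the e4m3fn bit pattern for -0.0 is NaN in e4m3fnuz.

import torch
from typing import Optional, Tuple

def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Reinterpret e4m3fn tensors as e4m3fnuz and adjust their scales."""
    assert weight.dtype == torch.float8_e4m3fn
    # 0x80 encodes -0.0 in e4m3fn but NaN in e4m3fnuz, so clear it first.
    weight_as_int8 = weight.view(torch.int8)
    weight_as_int8[weight_as_int8 == -128] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
    # The same bit pattern decodes to half the value in e4m3fnuz, so the
    # scales are doubled to keep the dequantized weights unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale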
@@ -495,6 +504,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
if current_platform.is_rocm():
w13_weight, w13_weight_scale_inv, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale_inv,
layer.w13_input_scale)
w2_weight, w2_weight_scale_inv, w2_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale_inv,
layer.w2_input_scale)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale_inv, requires_grad=False)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale_inv, requires_grad=False)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False)
return
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
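The MoE path repeats the same normalization for the fused w13 and w2 expert weights and, when present, their input scales. For intuition on why the scales are doubled, the following quick check (an assumed illustration, requiring a PyTorch build that exposes both fp8 dtypes; values shown are for bit pattern 0x40) shows the same bits decoding to half the magnitude in e4m3fnuz:

import torch

bits = torch.tensor([0x40], dtype=torch.uint8)        # arbitrary fp8 bit pattern
as_fn = bits.view(torch.float8_e4m3fn).float()        # 2.0 under e4m3fn
as_fnuz = bits.view(torch.float8_e4m3fnuz).float()    # 1.0 under e4m3fnuz
assert torch.equal(as_fn, 2.0 * as_fnuz)
# Doubling weight_scale therefore keeps weight * weight_scale unchanged.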
14 changes: 11 additions & 3 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -5,6 +5,8 @@
import triton
import triton.language as tl

from vllm.platforms import current_platform


def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
@@ -33,11 +35,14 @@ def apply_w8a8_block_fp8_linear(


def input_to_float8(
-        x: torch.Tensor,
-        dtype: torch.dtype = torch.float8_e4m3fn
+    x: torch.Tensor,
+    dtype: Optional[torch.dtype] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values "
"with tensor-wise quantization."""
if dtype is None:
dtype = (torch.float8_e4m3fnuz
if current_platform.is_rocm() else torch.float8_e4m3fn)
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
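For reference, the lines above are the start of the tensor-wise quantization: one scale is derived from the tensor's absolute maximum and values are saturated to the fp8 range. A standalone sketch of that idea (illustrative names, not the function's internals):

import torch

def tensorwise_fp8_quant_sketch(x: torch.Tensor, dtype: torch.dtype):
    # One scale for the whole tensor: map the largest magnitude onto the
    # fp8 type's representable maximum, saturate, then cast.
    finfo = torch.finfo(dtype)
    amax = x.abs().amax().clamp(min=1e-12)
    scale = finfo.max / amax
    x_q = (x.float() * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    # Return the dequantization scale (reciprocal of the quantization
    # scale), so that x is approximately x_q.float() * scale_out.
    return x_q, scale.reciprocal().float()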
@@ -125,7 +130,7 @@ def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
@@ -140,6 +145,9 @@
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
if dtype is None:
dtype = (torch.float8_e4m3fnuz
if current_platform.is_rocm() else torch.float8_e4m3fn)
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
