diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index 8f28580961666..9edebd9c3caae 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -16,13 +16,6 @@ __device__ __forceinline__ T silu_kernel(const T& x) {
   return (T)(((float)x) / (1.0f + expf((float)-x)));
 }
 
-__device__ __forceinline__ FP8_TYPE
-scaled_fp8_conversion(float const val, float const inverted_scale) {
-  float x = val * inverted_scale;
-  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
-  return static_cast<FP8_TYPE>(r);
-}
-
 // Activation and gating kernel template.
 template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void act_and_mul_quant_kernel(
@@ -74,8 +67,8 @@ __global__ void act_and_mul_quant_kernel(
 
 #pragma unroll
     for (int i = 0; i < elems_per_128bit_load; i++) {
-      out_vec[i] =
-          scaled_fp8_conversion(ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
+      out_vec[i] = scaled_fp8_conversion<true>(ACT_FN(x_vec[i]) * y_vec[i],
+                                               inverted_scale);
     }
 
     out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
@@ -87,7 +80,7 @@
          idx += blockDim.x) {
       const scalar_t x = VLLM_LDG(&x_ptr[idx]);
       const scalar_t y = VLLM_LDG(&y_ptr[idx]);
-      out_ptr[idx] = scaled_fp8_conversion(ACT_FN(x) * y, inverted_scale);
+      out_ptr[idx] = scaled_fp8_conversion<true>(ACT_FN(x) * y, inverted_scale);
     }
   }
 }
diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py
index 992108f7873e7..871e388ec2212 100644
--- a/tests/compile/test_functionalization.py
+++ b/tests/compile/test_functionalization.py
@@ -22,11 +22,7 @@
 SILU_MUL_OP = torch.ops._C.silu_and_mul.default
 
-SILU_MUL_QUANT_OPS = {
-    "static_fp8": [
-        torch.ops._C.silu_and_mul_quant.default,
-    ],
-}
+SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
 
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -90,10 +86,10 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     # and replaced by fused quantized ops in RMS_QUANT_OPS.
     rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
                ] if do_fusion else [RMS_OP]
-    silu_mul_ops = SILU_MUL_QUANT_OPS[
-        "static_fp8"] if do_fusion and quant_key == kFp8StaticTensorSym else [
-            SILU_MUL_OP
-        ]
+    silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
+        quant_key == kFp8StaticTensorSym else [
+            SILU_MUL_OP
+        ]
 
     ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
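
A note on the kernel-side change: the file-local FP8 helper deleted above multiplied the activated value by a pre-inverted scale, clamped it to the e4m3 range, and cast it to FP8_TYPE; after this diff the call sites go through the shared templated scaled_fp8_conversion<true> instead. Below is a minimal host-side C++ sketch of that saturating scale-and-clamp, for reference only: the _ref name, the float return value, and the host build are illustrative assumptions, and FP8_E4M3_MAX is taken as 448, the largest finite OCP e4m3 value.

#include <algorithm>
#include <cstdio>

// Largest finite value representable in OCP fp8 e4m3 (assumption for this sketch).
constexpr float FP8_E4M3_MAX = 448.0f;

// Reference stand-in (not vLLM's shared helper): multiply by the pre-inverted
// scale, then saturate to the e4m3 range. The real device helper additionally
// casts the clamped float to FP8_TYPE.
float scaled_fp8_conversion_ref(float val, float inverted_scale) {
  float x = val * inverted_scale;
  return std::max(-FP8_E4M3_MAX, std::min(x, FP8_E4M3_MAX));
}

int main() {
  std::printf("%.1f\n", scaled_fp8_conversion_ref(1000.0f, 1.0f));  // saturates to 448.0
  std::printf("%.1f\n", scaled_fp8_conversion_ref(100.0f, 0.5f));   // plain scaling: 50.0
  return 0;
}

Consolidating on one shared conversion helper keeps this clamp logic from drifting between the standalone fp8 quant kernels and the fused activation kernel.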