Commit 008b725: review comments and format

Signed-off-by: Sage Moore <[email protected]>
SageMoore committed Dec 20, 2024
Parent: ec1290a

Showing 2 changed files with 8 additions and 19 deletions.

csrc/quantization/activation_kernels.cu (13 changes: 3 additions & 10 deletions)

@@ -16,13 +16,6 @@ __device__ __forceinline__ T silu_kernel(const T& x) {
   return (T)(((float)x) / (1.0f + expf((float)-x)));
 }
 
-__device__ __forceinline__ FP8_TYPE
-scaled_fp8_conversion(float const val, float const inverted_scale) {
-  float x = val * inverted_scale;
-  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
-  return static_cast<FP8_TYPE>(r);
-}
-
 // Activation and gating kernel template.
 template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void act_and_mul_quant_kernel(
@@ -74,8 +67,8 @@ __global__ void act_and_mul_quant_kernel(
 
 #pragma unroll
     for (int i = 0; i < elems_per_128bit_load; i++) {
-      out_vec[i] =
-          scaled_fp8_conversion(ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
+      out_vec[i] = scaled_fp8_conversion<true>(ACT_FN(x_vec[i]) * y_vec[i],
+                                               inverted_scale);
     }
 
     out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
@@ -87,7 +80,7 @@ __global__ void act_and_mul_quant_kernel(
          idx += blockDim.x) {
       const scalar_t x = VLLM_LDG(&x_ptr[idx]);
       const scalar_t y = VLLM_LDG(&y_ptr[idx]);
-      out_ptr[idx] = scaled_fp8_conversion(ACT_FN(x) * y, inverted_scale);
+      out_ptr[idx] = scaled_fp8_conversion<true>(ACT_FN(x) * y, inverted_scale);
     }
   }
 }
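
The removed file-local helper is superseded by a shared, templated conversion routine; the new call sites instantiate it as scaled_fp8_conversion<true> because the kernel passes a precomputed inverted scale. A minimal sketch of the assumed shared helper, reconstructed from the deleted code (in vLLM this utility lives in a common FP8 header, e.g. csrc/quantization/fp8/common.cuh, which is also assumed to define FP8_TYPE and FP8_E4M3_MAX):

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  // With is_scale_inverted == true the caller passes 1/scale, so the
  // division becomes a multiply; if constexpr resolves the branch at
  // compile time, leaving no runtime cost.
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }
  // Saturate to the representable E4M3 range before narrowing.
  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  return static_cast<FP8_TYPE>(r);
}

Templating on is_scale_inverted keeps this kernel's hot loop at a single multiply, matching the deleted local helper, while letting callers that hold a non-inverted scale share the same code.
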
tests/compile/test_functionalization.py (14 changes: 5 additions & 9 deletions)

@@ -22,11 +22,7 @@
 
 SILU_MUL_OP = torch.ops._C.silu_and_mul.default
 
-SILU_MUL_QUANT_OPS = {
-    "static_fp8": [
-        torch.ops._C.silu_and_mul_quant.default,
-    ],
-}
+SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -90,10 +86,10 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     # and replaced by fused quantized ops in RMS_QUANT_OPS.
     rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
                ] if do_fusion else [RMS_OP]
-    silu_mul_ops = SILU_MUL_QUANT_OPS[
-        "static_fp8"] if do_fusion and quant_key == kFp8StaticTensorSym else [
-            SILU_MUL_OP
-        ]
+    silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
+        quant_key == kFp8StaticTensorSym else [
+            SILU_MUL_OP
+        ]
 
     ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
 
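
For reference, an illustrative Python sketch of the semantics the matched silu_and_mul_quant op implements, mirroring the CUDA kernel above (a sketch only: the real op is a compiled custom op, and its actual binding, e.g. an out-parameter signature, may differ):

import torch

def silu_and_mul_quant_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x holds [gate, up] concatenated on the last dimension.
    d = x.shape[-1] // 2
    out = torch.nn.functional.silu(x[..., :d]) * x[..., d:]
    # Static per-tensor FP8 quantization: the kernel multiplies by a
    # precomputed 1/scale, then clamps to the E4M3 range before narrowing.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    out = (out * (1.0 / scale)).clamp(min=-fp8_max, max=fp8_max)
    return out.to(torch.float8_e4m3fn)

The test assembles the op list above only to check which ops appear in the (de)functionalized graph, so this reference is purely explanatory.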
