[Bugfix][Kernel] Use int64_t for indices in fp8 quant kernels (vllm-p…
tlrmchlsmth authored Jul 22, 2024
1 parent: 739b61a · commit: fea59c7
Showing 2 changed files with 31 additions and 6 deletions.
12 changes: 6 additions & 6 deletions csrc/quantization/fp8/common.cu
@@ -103,11 +103,11 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
   vec4_t<scalar_t> const* vectorized_in =
       reinterpret_cast<vec4_t<scalar_t> const*>(input);
 
-  int const num_vec_elems = num_elems >> 2;
+  int64_t const num_vec_elems = num_elems >> 2;
   float absmax_val = 0.0f;
 
 #pragma unroll 4
-  for (int i = tid; i < num_vec_elems; i += step) {
+  for (int64_t i = tid; i < num_vec_elems; i += step) {
     vec4_t<scalar_t> in_vec = vectorized_in[i];
     absmax_val = max(absmax_val, fabs(in_vec.x));
     absmax_val = max(absmax_val, fabs(in_vec.y));
@@ -116,7 +116,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
   }
 
   // Handle the remaining elements if num_elems is not divisible by 4
-  for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
     absmax_val = max(absmax_val, fabs(input[i]));
   }
 
@@ -134,10 +134,10 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
       reinterpret_cast<vec4_t<scalar_t> const*>(input);
   float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
 
-  int const num_vec_elems = num_elems >> 2;
+  int64_t const num_vec_elems = num_elems >> 2;
 
 #pragma unroll 4
-  for (int i = tid; i < num_vec_elems; i += step) {
+  for (int64_t i = tid; i < num_vec_elems; i += step) {
     vec4_t<scalar_t> in_vec = vectorized_in[i];
     float8x4_t out_vec;
 
@@ -153,7 +153,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
   }
 
   // Handle the remaining elements if num_elems is not divisible by 4
-  for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
     out[i] = scaled_fp8_conversion<is_scale_inverted>(
         static_cast<float>(input[i]), scale);
   }
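The change above is mechanical but worth spelling out: these kernels walk the flattened tensor with a grid-stride loop, and widening the loop index (and the derived num_vec_elems) from int to int64_t keeps the indexing valid when the element count approaches or exceeds what a 32-bit signed integer can hold. Below is a minimal, self-contained sketch of that pattern, not the vLLM kernels themselves; the kernel name scale_elements, its signature, and the commented-out launch are invented for illustration, and the thread id and stride are widened to int64_t here for simplicity, whereas the kernels in this diff receive tid and step as parameters.

#include <cstdint>
#include <cuda_runtime.h>

// Sketch of a grid-stride loop with 64-bit indexing. Because num_elems and i
// are int64_t, the loop can address tensors whose element count does not fit
// in a 32-bit signed integer.
__global__ void scale_elements(float* __restrict__ out,
                               float const* __restrict__ in, float const scale,
                               int64_t const num_elems) {
  int64_t const tid =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  int64_t const step = static_cast<int64_t>(blockDim.x) * gridDim.x;

  for (int64_t i = tid; i < num_elems; i += step) {
    out[i] = in[i] * scale;
  }
}

// Hypothetical launch: the grid stays a fixed, modest size; the 64-bit index
// is what lets each thread stride past large offsets without wrapping.
// scale_elements<<<1024, 256>>>(d_out, d_in, 0.5f, num_elems);

The mismatch only surfaces once activations are large enough, which is why the regression test added below constructs a deliberately oversized input.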
25 changes: 25 additions & 0 deletions tests/kernels/test_fp8_quant.py
@@ -60,3 +60,28 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
     assert torch.allclose(ref_scale, ops_scale)
     assert torch.allclose(ref_out.to(dtype=torch.float32),
                           ops_out.to(dtype=torch.float32))
+
+
+# Regression test for a case with large activations where an int32 index cannot
+# represent the number of elements.
+@torch.inference_mode()
+@pytest.mark.parametrize("seed", SEEDS)
+def test_fp8_quant_large(seed: int) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    num_tokens = 1024000  # Mistral-Nemo's max_position_embeddings
+    hidden_size = 1152  # Smallest hidden_size to reproduce the error
+    dtype = torch.bfloat16
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    ref_out, scale = ref_dynamic_per_tensor_fp8_quant(x)
+    ops_out, _ = ops.scaled_fp8_quant(x, scale)
+
+    # Minimize memory footprint in this test by freeing x and upconverting
+    # the outputs in place. (torch.allclose does not support fp8)
+    del x
+    ref_out = ref_out.to(dtype=dtype)
+    ops_out = ops_out.to(dtype=dtype)
+
+    assert torch.allclose(ref_out, ops_out)
