Skip to content

Commit

Permalink
Fix sampling errors due to float rounding errors on large-vocab models
Browse files Browse the repository at this point in the history
  • Loading branch information
turboderp committed Jul 13, 2024
1 parent 44dd05b commit b8a5bcf
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
7 changes: 6 additions & 1 deletion exllamav2/exllamav2_ext/cpp/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,12 @@ int multinomial_cpu
while (true)
{
if (accum >= random) break;
if (idx == num_candidates - 1) break;
if (idx == num_candidates - 1)
{
// Roll back in case the sampled probability is exactly zero
while (idx > 0 && temp_probs[idx] == 0.0f) idx--;
break;
}
idx++;
accum += temp_probs[idx];
}
Expand Down
1 change: 1 addition & 0 deletions exllamav2/exllamav2_ext/cpp/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF4(__x, __y, __z, __w) printf("%s, %s, %s, %s: %f, %f, %f, %f\n", #__x, #__y, #__z, #__w, __x, __y, __z, __w)
#define DBGIF(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __y)

#define TIME_START \
Expand Down
19 changes: 18 additions & 1 deletion exllamav2/exllamav2_ext/ext_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,24 @@ std::vector<float> sample_basic
random_s = powf(random, expf(-skew));
}

multinomial_cpu(num_candidates, temp_probs, temp_indices, random_s);
// {
// float sum = 0.0f;
// float pmin = temp_probs[0];
// float pmax = pmin;
// for (int i = 0; i < num_candidates; ++i)
// {
// if (temp_probs[i] < pmin) pmin = temp_probs[i];
// if (temp_probs[i] > pmax) pmax = temp_probs[i];
// sum += temp_probs[i];
// }
// DBGF4(pmin, pmax, sum, random_s);
// }

// Scale random sampling point a little to account for FP32 rounding errors during softmax. Probs
// can potentially sum to slightly less than 1 for large-vocab models
float random_s_adj = random_s * 0.9998;

multinomial_cpu(num_candidates, temp_probs, temp_indices, random_s_adj);

output_tokens[i][0] = temp_indices[0];
output_probs[i][0] = temp_probs[0];
Expand Down

0 comments on commit b8a5bcf

Please sign in to comment.