Skip to content

Commit

Permalink
server : fix parallel speculative decoding (#10513)
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov authored Nov 26, 2024
1 parent 811872a commit 84e1c33
Showing 1 changed file with 31 additions and 32 deletions.
63 changes: 31 additions & 32 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2267,50 +2267,49 @@ struct server_context {
continue; // continue loop of slots
}

llama_token id;
llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);

{
completion_token_output result;

id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
slot.i_batch = -1;

slot.i_batch = -1;
common_sampler_accept(slot.smpl, id, true);

common_sampler_accept(slot.smpl, id, true);

slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}

result.tok = id;
completion_token_output result;
result.tok = id;

const auto * cur_p = common_sampler_get_candidates(slot.smpl);
const auto * cur_p = common_sampler_get_candidates(slot.smpl);

for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
result.probs.push_back({
cur_p->data[i].id,
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
});
}
for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
result.probs.push_back({
cur_p->data[i].id,
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
});
}

if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
continue;
}
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
continue;
}
}

// check if the slot supports speculative decoding
if (!slot.can_speculate()) {
// do speculative decoding
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.can_speculate()) {
continue;
}

llama_token id = slot.sampled;

struct common_speculative_params params_spec;
params_spec.n_draft = slot.params.speculative.n_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
Expand Down

0 comments on commit 84e1c33

Please sign in to comment.