Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement stochastic speculative sampling #5625

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,12 +497,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_sequences = std::stoi(argv[i]);
} else if (arg == "--p-accept" || arg == "-pa") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.p_accept = std::stof(argv[i]);
} else if (arg == "--p-split" || arg == "-ps") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -1020,7 +1014,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
Expand Down
3 changes: 1 addition & 2 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ struct gpt_params {
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_draft = 8; // number of tokens to draft during speculative decoding
int32_t n_draft = 5; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
float p_accept = 0.5f; // speculative decoding accept probability
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
Expand Down
74 changes: 74 additions & 0 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,72 @@ static llama_token llama_sampling_sample_impl(
return id;
}

static llama_token_data_array llama_sample_probability_distribution_impl(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
const llama_sampling_params & params = ctx_sampling->params;

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const bool penalize_nl = params.penalize_nl;

auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;

// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);

// Declare original_logits at the beginning of the function scope
std::vector<float> original_logits;

// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}

if (ctx_cfg) {
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
}

cur.clear();

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}

llama_token_data_array cur_p = { cur.data(), cur.size(), false };

// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

llama_sample_repetition_penalties(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
cur_p.data[idx].logit = nl_logit;
break;
}
}
}
}

llama_sample_softmax(ctx_main, &cur_p);
return cur_p;
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there some way to reuse the code from llama_sampling_sample and avoid the duplication? Also, this function does not take into account the grammar - is this correct?

Copy link
Collaborator Author

@mscheong01 mscheong01 Mar 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A solution would be to have llama_sampling_sample utilize llama_sample_probability_distribution internally. However, this approach appeared too intrusive to the existing codebase at the time, which led to duplicating the code instead.

+ Yes it seems like I should apply grammar constraints here


llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
Expand All @@ -304,6 +370,14 @@ llama_token llama_sampling_sample(
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
}

llama_token_data_array llama_sampling_probability_distribution(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
}

void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
Expand Down
7 changes: 7 additions & 0 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ llama_token llama_sampling_sample(
struct llama_context * ctx_cfg,
int idx = 0);

// returns the probability that token of given id will be sampled
llama_token_data_array llama_sampling_probability_distribution(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0);

void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
Expand Down
Loading
Loading