Skip to content

Commit

Permalink
speculative : ensure draft and target model vocab matches (ggerganov#…
Browse files Browse the repository at this point in the history
…3812)

* speculative: Ensure draft and target model vocab matches

* Tolerate small differences when checking dft vs tgt vocab
  • Loading branch information
KerfuffleV2 authored Oct 27, 2023
1 parent 6d459cb commit 41aee4d
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion examples/speculative/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include <string>
#include <vector>

#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

struct seq_draft {
bool active = false;
bool drafting = false;
Expand Down Expand Up @@ -64,6 +67,33 @@ int main(int argc, char ** argv) {
params.n_gpu_layers = params.n_gpu_layers_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);

{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;

if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return 1;
}

for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
llama_token_to_piece(ctx_tgt, i).c_str(),
llama_token_to_piece(ctx_dft, i).c_str());
return 1;
}
}
}

// tokenize the prompt
std::vector<llama_token> inp;
inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
Expand Down Expand Up @@ -227,6 +257,7 @@ int main(int argc, char ** argv) {
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);

llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode (ctx_dft, batch_dft);

++n_past_dft;
Expand Down Expand Up @@ -370,7 +401,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
}

//LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
++n_past_tgt;
}
Expand Down

0 comments on commit 41aee4d

Please sign in to comment.