Skip to content

Commit

Permalink
Start prompt wasn't evaluated correctly with the default batch_size
Browse files Browse the repository at this point in the history
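Bug fix:
In versions 0.1.3-0.1.5 the start prompt was not evaluated correctly with the new default batch_size of 64: every batch of the prompt was evaluated with n_past = 0 instead of the number of tokens already evaluated, so n_past is now advanced by the chunk size after each batch.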
  • Loading branch information
Mozer authored Apr 30, 2024
1 parent 1e58b10 commit 1553710
Showing 1 changed file with 12 additions and 54 deletions.
66 changes: 12 additions & 54 deletions examples/talk-llama/talk-llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,42 +421,6 @@ void allow_xtts_file(std::string path, int xtts_play_allowed) {
}
}

// writes to temp file 0 or 1
// @path: not used anymore, we store in temp
// @xtts_play_allowed: 0=dont play xtts, 1=xtts can play
/*
void allow_xtts_file(std::string path, int xtts_play_allowed) {
std::filesystem::path filePath = getTempDir() / "xtts_play_allowed.txt";
// Check if the file exists and read the current value
std::string currentValue;
if (std::filesystem::exists(filePath)) {
std::ifstream file(filePath);
if (file.is_open()) {
std::getline(file, currentValue);
file.close();
}
}
std::string xtts_play_allowed_str = std::to_string(xtts_play_allowed);
// Only update the file if the value is different
if (currentValue != xtts_play_allowed_str) {
std::ofstream file(filePath, std::ios::trunc);
if (file.is_open()) {
file << xtts_play_allowed_str;
file.close();
return true;
} else {
printf("Unable to open file: %s\n", filePath)
return false;
}
}
return true;
}
*/

// trim from start (in place)
inline void ltrim(std::string &s) {
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
Expand Down Expand Up @@ -1051,18 +1015,7 @@ int run(int argc, const char ** argv) {
exit(0);
}

//const std::string fileName{params.xtts_control_path};
//std::ifstream readStream{fileName};
//if(!readStream.good()){
// printf("Warning: %s file not found, xtts wont stop on user speech without it\n", params.xtts_control_path.c_str());
// readStream.close();
//}
//else // control file is ok
//{
// readStream.close();
// allow_xtts_file(params.xtts_control_path, 1); // xtts can play
//}

allow_xtts_file(params.xtts_control_path, 1); // xtts can play

// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
Expand Down Expand Up @@ -1219,15 +1172,18 @@ int run(int argc, const char ** argv) {
float llama_start_time = get_current_time_ms();

// NEW prompt eval
int n_past = 0;
// Calculate the number of chunks needed
size_t num_chunks = (embd_inp.size() + lcparams.n_batch - 1) / lcparams.n_batch;
// Iterate through the chunks and evaluate them
for (size_t i = 0; i < num_chunks; i++) {
// Calculate the start and end indices for the current chunk
size_t start_idx = i * lcparams.n_batch;
size_t end_idx = std::min((i + 1) * lcparams.n_batch, embd_inp.size());
size_t chunk_size = end_idx - start_idx;
// Evaluate the current chunk
llama_eval(ctx_llama, embd_inp.data() + start_idx, end_idx - start_idx, 0);
llama_eval(ctx_llama, embd_inp.data() + start_idx, chunk_size, n_past);
n_past += chunk_size;
}

float llama_end_time = get_current_time_ms();
Expand Down Expand Up @@ -1288,7 +1244,7 @@ int run(int argc, const char ** argv) {
const int n_keep = embd_inp.size();
const int n_ctx = llama_n_ctx(ctx_llama);

int n_past = n_keep;
n_past = n_keep;
int n_prev = 64;
std::vector<int> past_prev_arr{};
int n_past_prev = 0; // token count that was before the last answer
Expand Down Expand Up @@ -1492,7 +1448,7 @@ int run(int argc, const char ** argv) {
} else {
text_heard += words[i] + " ";
}
}
}
if (params.print_energy) fprintf(stdout, " [text_heard: (%s)]\n", text_heard.c_str());

// check if audio starts with the wake-up command if enabled
Expand Down Expand Up @@ -1712,7 +1668,7 @@ int run(int argc, const char ** argv) {
{
printf(" [Resetting context of %d tokens.]\n", embd_inp.size());
embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
n_past = embd_inp.size();
n_past = 0;
// NEW prompt eval
// Calculate the number of chunks needed
size_t num_chunks = (embd_inp.size() + lcparams.n_batch - 1) / lcparams.n_batch;
Expand All @@ -1721,8 +1677,10 @@ int run(int argc, const char ** argv) {
// Calculate the start and end indices for the current chunk
size_t start_idx = i * lcparams.n_batch;
size_t end_idx = std::min((i + 1) * lcparams.n_batch, embd_inp.size());
size_t chunk_size = end_idx - start_idx;
// Evaluate the current chunk
llama_eval(ctx_llama, embd_inp.data() + start_idx, end_idx - start_idx, 0);
llama_eval(ctx_llama, embd_inp.data() + start_idx, chunk_size, n_past);
n_past += chunk_size;
}
printf(" [Context is now %d/%d tokens]\n", embd_inp.size(), params.ctx_size);

Expand Down Expand Up @@ -1927,7 +1885,7 @@ int run(int argc, const char ** argv) {

// NEW prompt eval
size_t total_size = embd.size();
size_t batch_size = lcparams.n_batch;
//size_t batch_size = lcparams.n_batch;
// Split the input embeddings into smaller chunks
for (size_t i = 0; i < total_size; i += lcparams.n_batch) {
size_t chunk_size = (total_size - i < lcparams.n_batch) ? (total_size - i) : lcparams.n_batch;
Expand Down

0 comments on commit 1553710

Please sign in to comment.