Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
Mozer authored Feb 23, 2024
1 parent f311104 commit 3c1c809
Showing 1 changed file with 48 additions and 73 deletions.
121 changes: 48 additions & 73 deletions examples/talk-llama/talk-llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,16 @@ struct whisper_params {

std::string person = "Georgi";
std::string bot_name = "LLaMA";
std::string xtts_voice = "emma_1";
std::string wake_cmd = "";
std::string heard_ok = "";
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_llama = "models/ggml-llama-7B.bin";
std::string speak = "./examples/talk-llama/speak";
std::string xtts_control_path = "c:\\DATA\\LLM\\xtts\\xtts_play_allowed.txt";
std::string xtts_url = "http://localhost:8020/";
std::string google_url = "http://localhost:8003/";
std::string prompt = "";
std::string fname_out;
std::string path_session = ""; // path to file for saving/loading model eval state
Expand Down Expand Up @@ -154,6 +158,10 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params)
else if (arg == "--top_k") { params.top_k = std::stof(argv[++i]); }
else if (arg == "--top_p") { params.top_p = std::stof(argv[++i]); }
else if (arg == "--repeat_penalty") { params.repeat_penalty = std::stof(argv[++i]); }
else if (arg == "--xtts-voice") { params.xtts_voice = argv[++i]; }
else if (arg == "--xtts-url") { params.xtts_url = argv[++i]; }
else if (arg == "--google-url") { params.google_url = argv[++i]; }
else if (arg == "--xtts-control-path") { params.xtts_control_path = argv[++i]; }
else if (arg == "--prompt-file") {
std::ifstream file(argv[++i]);
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
Expand Down Expand Up @@ -211,6 +219,10 @@ void whisper_print_usage(int /*argc*/, const char ** argv, const whisper_params
fprintf(stderr, " --top_k N [%-7.2f] top_k \n", params.top_k);
fprintf(stderr, " --top_p N [%-7.2f] top_p \n", params.top_p);
fprintf(stderr, " --repeat_penalty N [%-7.2f] repeat_penalty \n", params.repeat_penalty);
fprintf(stderr, " --xtts-voice NAME [%-7s] xtts voice without .wav\n", params.xtts_voice.c_str());
fprintf(stderr, " --xtts-url TEXT [%-7s] xtts/silero server URL, with trailing slash\n", params.xtts_url.c_str());
fprintf(stderr, " --xtts-control-path FNAME [%-7s] path to xtts_play_allowed.txt", params.xtts_control_path.c_str());
fprintf(stderr, " --google-url TEXT [%-7s] langchain google-serper server URL, with /\n", params.google_url.c_str());
fprintf(stderr, "\n");
}

Expand Down Expand Up @@ -310,7 +322,7 @@ void allow_xtts_file(std::string path, int xtts_play_allowed) {
bool doesExistAndIsReadable{readStream.good()};

if(!doesExistAndIsReadable){
printf("%s file not found", path.c_str());
//printf("%s file not found", path.c_str());
}

std::getline(readStream, singleLine);
Expand Down Expand Up @@ -492,34 +504,9 @@ std::string send_curl(std::string url)

return readBuffer;
}
/*
void parse_url(const string& raw_url) //no boost
{
std::string path,domain,x,protocol,port,query;
int offset = 0;
size_t pos1,pos2,pos3,pos4;
x = _trim(raw_url);
offset = offset==0 && x.compare(0, 8, "https://")==0 ? 8 : offset;
offset = offset==0 && x.compare(0, 7, "http://" )==0 ? 7 : offset;
pos1 = x.find_first_of('/', offset+1 );
path = pos1==string::npos ? "" : x.substr(pos1);
domain = string( x.begin()+offset, pos1 != string::npos ? x.begin()+pos1 : x.end() );
path = (pos2 = path.find("#"))!=string::npos ? path.substr(0,pos2) : path;
port = (pos3 = domain.find(":"))!=string::npos ? domain.substr(pos3+1) : "";
domain = domain.substr(0, pos3!=string::npos ? pos3 : domain.length());
protocol = offset > 0 ? x.substr(0,offset-3) : "";
query = (pos4 = path.find("?"))!=string::npos ? path.substr(pos4+1) : "";
path = pos4!=string::npos ? path.substr(0,pos4) : path;
cout << "[" << raw_url << "]" << endl;
cout << "protocol: " << protocol << endl;
cout << "domain: " << domain << endl;
cout << "port: " << port << endl;
cout << "path: " << path << endl;
cout << "query: " << query << endl;
}
*/

// send post without waiting for reply
// not used
std::string socket_post(const std::string &url, const std::map<std::string, std::string>& params)
{
printf(" in socket_post\n ");
Expand Down Expand Up @@ -663,28 +650,18 @@ std::string socket_post(const std::string &url, const std::map<std::string, std:

// async curl, but it's still blocking for some reason
// doesn't wait for responce
void send_tts_async(std::string text, std::string speaker_wav="emma_1", std::string language="en")
void send_tts_async(std::string text, std::string speaker_wav="emma_1", std::string language="en", std::string tts_url="http://localhost:8020/")
{
std::string url = "http://localhost:8020/tts_to_audio/";
//printf(" in send_tts_async: %s; size: %d\n",text.c_str(), text.size());
if (text.size() && text != "." && text != "," && text != "!" && text != "\n")
{
trim(text);
text = ::replace(text, "\r", "");
text = ::replace(text, "\n", " ");
text = ::replace(text, "\"", "");
//printf("send_tts_async sending, size:%d\n", text.size());

//std::map<std::string, std::string> params={{"text", text}, {"speaker_wav", speaker_wav}, {"language", language}};
//send_curl_json(url, params);
//socket_post(url, params);
//printf("after socket_post\n");

//std::wstring w_text = std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(text.c_str());
//printf("trying to w_text %s, size:%d\n", w_text, w_text.size());
//text = console::UTF16toUTF8(w_text);
//printf("trying to curl text %s, size:%d\n", text, text.size());

tts_url= tts_url + "tts_to_audio/";
//printf("send_tts_async sending, url: %s\n", tts_url.c_str());
//for (char ch : tts_url) printf("%X ", ch);
//printf("\n");


CURL *http_handle;
Expand All @@ -696,7 +673,7 @@ void send_tts_async(std::string text, std::string speaker_wav="emma_1", std::str
//fprintf(stdout, " [data (%s)]\n", data.c_str());

curl_easy_setopt(http_handle, CURLOPT_HTTPHEADER, curl_slist_append(nullptr, "Content-Type:application/json"));
curl_easy_setopt(http_handle, CURLOPT_URL, "http://localhost:8020/tts_to_audio/");
curl_easy_setopt(http_handle, CURLOPT_URL, tts_url.c_str());
curl_easy_setopt(http_handle, CURLOPT_POSTFIELDS, data.c_str());
curl_easy_setopt(http_handle, CURLOPT_VERBOSE, 0L);

Expand Down Expand Up @@ -724,18 +701,6 @@ void send_tts_async(std::string text, std::string speaker_wav="emma_1", std::str
}
}


void thread_fn( std::string text/*, std::string language="en" */) {
if (text.size())
{
fprintf(stdout, " [before thread, [%d] %s]\n", text.c_str());
send_tts_async(text, "emma_1", "en");
fprintf(stdout," [after thread] ");
//text_to_speak_arr[thread_i-1] = "";
}
}


const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";

const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
Expand Down Expand Up @@ -765,13 +730,19 @@ int run(int argc, const char ** argv) {
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
printf("in wmain 2\n");

if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}

const std::string fileName{params.xtts_control_path};
std::ifstream readStream{fileName};
if(!readStream.good()){
printf("Warning: %s file not found, xtts wont stop on user speech without it\n", params.xtts_control_path.c_str());
}
readStream.close();

// whisper init

Expand Down Expand Up @@ -845,7 +816,9 @@ int run(int argc, const char ** argv) {
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;

const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", params.bot_name);
std::string prompt_whisper;
if (params.language == "ru") std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", "Анна"); // Алиса is bad
else std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", params.bot_name);

// construct the initial prompt for LLaMA inference
std::string prompt_llama = params.prompt.empty() ? k_prompt_llama : params.prompt;
Expand Down Expand Up @@ -998,8 +971,7 @@ int run(int argc, const char ** argv) {
// reverse prompts for detecting when it's time to stop speaking
std::vector<std::string> antiprompts = {
params.person + chat_symb,
"\n",
"Sergey:",
"\n"
};

// main loop
Expand All @@ -1025,7 +997,7 @@ int run(int argc, const char ** argv) {
{
// user has started speaking, xtts cannot play
//fprintf(stdout, "%s: Speech start! ...\n", __func__);
allow_xtts_file("c:\\DATA\\LLM\\xtts\\xtts_play_allowed.txt", 0);
allow_xtts_file(params.xtts_control_path, 0);
}
if (vad_result >= 2 || force_speak) // speech ended
{
Expand Down Expand Up @@ -1096,7 +1068,9 @@ int run(int argc, const char ** argv) {
text_heard = RemoveTrailingCharacters(text_heard, ',');
text_heard = RemoveTrailingCharacters(text_heard, '.');
if (text_heard[0] == '.') text_heard.erase(0, 1);
if (text_heard == "!" || text_heard == "." || text_heard == "Sil" || text_heard == "Okay" || text_heard == "Okay." || text_heard == "Thank you." || text_heard == "Thank you" || text_heard == "Thanks." || text_heard == "Bye." || text_heard == "Thank you for listening." || text_heard == "К" || text_heard == "Спасибо" || text_heard == params.bot_name || text_heard == "*Звук!*" || text_heard == "Р" || text_heard.find("Редактор субтитров")!= std::string::npos || text_heard.find("можешь это сделать")!= std::string::npos || text_heard.find("Как дела?")!= std::string::npos) text_heard = "";
if (text_heard[0] == '!') text_heard.erase(0, 1);
trim(text_heard);
if (text_heard == "!" || text_heard == "." || text_heard == "Sil" || text_heard == "Okay" || text_heard == "Okay." || text_heard == "Thank you." || text_heard == "Thank you" || text_heard == "Thanks." || text_heard == "Bye." || text_heard == "Thank you for listening." || text_heard == "К" || text_heard == "Спасибо" || text_heard == params.bot_name || text_heard == "*Звук!*" || text_heard == "Р" || text_heard.find("Редактор субтитров")!= std::string::npos || text_heard.find("можешь это сделать")!= std::string::npos || text_heard.find("Как дела?")!= std::string::npos || text_heard.find("Это")!= std::string::npos || text_heard.find("Добро пожаловать")!= std::string::npos) text_heard = "";
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), ""); // trailing whitespace


Expand All @@ -1107,6 +1081,7 @@ int run(int argc, const char ** argv) {
text_heard_trimmed = text_heard; // no periods or spaces
trim(text_heard_trimmed);
if (text_heard_trimmed[0] == '.') text_heard_trimmed.erase(0, 1);
if (text_heard_trimmed[0] == '!') text_heard_trimmed.erase(0, 1);
if (text_heard_trimmed[text_heard_trimmed.length() - 1] == '.' || text_heard_trimmed[text_heard_trimmed.length() - 1] == '!') text_heard_trimmed.erase(text_heard_trimmed.length() - 1, 1);
trim(text_heard_trimmed);
text_heard_trimmed = LowerCase(text_heard_trimmed); // not working right with utf and russian
Expand All @@ -1124,7 +1099,7 @@ int run(int argc, const char ** argv) {
else if (text_heard_trimmed.find("stop") != std::string::npos || text_heard_trimmed.find("Стоп") != std::string::npos || text_heard_trimmed.find("Остановись") != std::string::npos || text_heard_trimmed.find("тановись") != std::string::npos || text_heard_trimmed.find("Хватит") != std::string::npos || text_heard_trimmed.find("Становись") != std::string::npos) user_command = "stop";

// user has finished speaking, xtts can play
allow_xtts_file("c:\\DATA\\LLM\\xtts\\xtts_play_allowed.txt", 1);
allow_xtts_file(params.xtts_control_path, 1);

if (user_command.size() && !new_command_allowed && std::time(0)-last_command_time >= 1)
{
Expand All @@ -1149,7 +1124,7 @@ int run(int argc, const char ** argv) {
n_past -= rollback_num;
text_heard = text_heard_prev;
text_heard_trimmed = "";
//send_tts_async("Regenerating", "ux", params.language);
send_tts_async("Regenerating", params.xtts_voice, params.language, params.xtts_url);
new_command_allowed = 0;
}
}
Expand Down Expand Up @@ -1184,7 +1159,7 @@ int run(int argc, const char ** argv) {
n_past -= rollback_num;
text_heard = "";
text_heard_trimmed = "";
send_tts_async("Deleted", "ux", params.language);
send_tts_async("Deleted", params.xtts_voice, params.language, params.xtts_url);
last_command_time = std::time(0);
//printf("last_command_time: %d\n", last_command_time);
new_command_allowed = 0;
Expand Down Expand Up @@ -1216,14 +1191,14 @@ int run(int argc, const char ** argv) {
n_past -= rollback_num;
text_heard = "";
text_heard_trimmed = "";
send_tts_async("Reset whole context.", "ux", params.language);
send_tts_async("Reset whole context", params.xtts_voice, params.language, params.xtts_url);
new_command_allowed = 0;
}
}
else
{
printf("Nothing to reset more\n");
send_tts_async("Nothing to reset more", "ux", params.language);
printf("Nothing to reset more\n");
send_tts_async("Nothing to reset more", params.xtts_voice, params.language, params.xtts_url);
}
}
audio.clear();
Expand All @@ -1239,7 +1214,7 @@ int run(int argc, const char ** argv) {
std::string q = ParseCommandAndGetKeyword(text_heard_trimmed);
if (q.size())
{
std::string url = "http://localhost:8003/google?q="+UrlEncode(q);
std::string url = params.google_url+"google?q="+UrlEncode(q);
google_resp = send_curl(url);
if (google_resp.size())
{
Expand All @@ -1257,7 +1232,7 @@ int run(int argc, const char ** argv) {
{
threads.emplace_back([&] // creates and starts a thread
{
if (google_resp.size()) send_tts_async(google_resp, "ux", "ru"); //params.language
if (google_resp.size()) send_tts_async(google_resp, params.xtts_voice, params.language, params.xtts_url);
});
thread_i++;
}
Expand Down Expand Up @@ -1467,7 +1442,7 @@ int run(int argc, const char ** argv) {
{
if (text_to_speak_arr[thread_i-1].size())
{
send_tts_async(text_to_speak_arr[thread_i-1], "emma_1", params.language);
send_tts_async(text_to_speak_arr[thread_i-1], params.xtts_voice, params.language, params.xtts_url);
text_to_speak_arr[thread_i-1] = "";
}
});
Expand All @@ -1485,7 +1460,7 @@ int run(int argc, const char ** argv) {
{
// user has started speaking, xtts cannot play
fprintf(stdout, " [Speech detected! Aborting ...]\n");
allow_xtts_file("c:\\DATA\\LLM\\xtts\\xtts_play_allowed.txt", 0);
allow_xtts_file(params.xtts_control_path, 0);
done = true; // generation stops
break;
}
Expand Down Expand Up @@ -1539,7 +1514,7 @@ int run(int argc, const char ** argv) {
{
if (text_to_speak_arr[thread_i-1].size())
{
send_tts_async(text_to_speak_arr[thread_i-1], "emma_1", params.language);
send_tts_async(text_to_speak_arr[thread_i-1], params.xtts_voice, params.language, params.xtts_url);
text_to_speak_arr[thread_i-1] = "";
}
});
Expand Down

0 comments on commit 3c1c809

Please sign in to comment.