Skip to content

Commit

Permalink
Add moonshine_live.exe to README and refactor audio handling in live …
Browse files Browse the repository at this point in the history
…example
  • Loading branch information
royshil committed Nov 6, 2024
1 parent ebcdc37 commit a75823c
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 29 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ After building you should have a folder like so (e.g. on Windows):
./dist
├───bin
│ moonshine_example.exe
│ moonshine_live.exe
│ onnxruntime.dll
│ onnxruntime_providers_shared.dll
Expand Down
86 changes: 57 additions & 29 deletions examples/live.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@
const int SAMPLE_RATE = 16000;
const int BUFFER_SIZE = 4096;

std::vector<float> convertToFloat(const std::vector<int16_t>& pcm_data)
void audioCallback(void* userdata, Uint8* stream, int len)
{
std::vector<float> float_data;
float_data.reserve(pcm_data.size());
for (const auto& pcm_sample : pcm_data)
{
float_data.push_back(static_cast<float>(pcm_sample) / 32768.0f);
}
return float_data;
std::vector<float>* buffer = static_cast<std::vector<float>*>(userdata);
float* samples = reinterpret_cast<float*>(stream);
int sample_count = len / sizeof(float);
buffer->insert(buffer->end(), samples, samples + sample_count);
}

void audioCallback(void* userdata, Uint8* stream, int len)
void listAudioDevices()
{
std::vector<int16_t>* buffer = static_cast<std::vector<int16_t>*>(userdata);
int16_t* samples = reinterpret_cast<int16_t*>(stream);
int sample_count = len / sizeof(int16_t);
buffer->insert(buffer->end(), samples, samples + sample_count);
int count = SDL_GetNumAudioDevices(SDL_TRUE); // SDL_TRUE for recording devices
std::cout << "Available recording devices:\n";
for (int i = 0; i < count; ++i)
{
const char* name = SDL_GetAudioDeviceName(i, SDL_TRUE);
std::cout << i << ": " << (name ? name : "Unknown Device") << "\n";
}
}

int main(int argc, char* argv[])
Expand Down Expand Up @@ -60,66 +60,94 @@ int main(int argc, char* argv[])

std::cout << "SDL initialized\n";

// List available devices
listAudioDevices();

// Set up audio capture
SDL_AudioSpec desired_spec;
SDL_AudioSpec obtained_spec;
SDL_zero(desired_spec);
desired_spec.freq = SAMPLE_RATE;
desired_spec.format = AUDIO_S16LSB;
desired_spec.format = AUDIO_F32;
desired_spec.channels = 1;
desired_spec.samples = BUFFER_SIZE;
desired_spec.callback = audioCallback;

std::vector<int16_t> audio_buffer;
std::vector<float> audio_buffer;
desired_spec.userdata = &audio_buffer;

if (SDL_OpenAudio(&desired_spec, &obtained_spec) < 0)
// Open the default recording device
SDL_AudioDeviceID dev = SDL_OpenAudioDevice(NULL, // device name (NULL for default)
SDL_TRUE, // is_capture (recording)
&desired_spec, // desired spec
&obtained_spec, // obtained spec
SDL_AUDIO_ALLOW_FORMAT_CHANGE);

if (dev == 0)
{
std::cerr << "Could not open audio: " << SDL_GetError() << "\n";
std::cerr << "Could not open audio device: " << SDL_GetError() << "\n";
SDL_Quit();
return 1;
}

std::cout << "Audio opened\n";
std::cout << "Audio device opened: " << SDL_GetAudioDeviceName(0, SDL_TRUE) << "\n";
// print the obtained spec
std::cout << "Obtained spec: " << obtained_spec.freq << " Hz, "
<< SDL_AUDIO_BITSIZE(obtained_spec.format) << " bits, "
<< (obtained_spec.channels == 1 ? "mono" : "stereo") << "\n";

// Start audio capture
SDL_PauseAudio(0);
SDL_PauseAudioDevice(dev, 0);

std::atomic<bool> running(true);
std::thread transcription_thread(
[&]()
{
std::cout << "Transcribing...\n";
size_t last_buffer_size = 0;
while (running)
{
if (audio_buffer.size() >= SAMPLE_RATE)
{
if (audio_buffer.size() == last_buffer_size)
{
// No new audio data
std::this_thread::sleep_for(std::chrono::milliseconds(100));
continue;
}
last_buffer_size = audio_buffer.size();

// Process audio buffer
std::vector<int16_t> buffer(audio_buffer.begin(),
audio_buffer.begin() + SAMPLE_RATE);
audio_buffer.erase(audio_buffer.begin(),
audio_buffer.begin() + SAMPLE_RATE);
std::vector<float> buffer(audio_buffer.begin(), audio_buffer.end());

// Convert audio buffer to float
std::vector<float> audio_samples = convertToFloat(buffer);
// Limit the buffer size to 10 seconds
if (audio_buffer.size() > 10 * SAMPLE_RATE)
{
audio_buffer.erase(audio_buffer.begin(), audio_buffer.end());
}

// Generate tokens
auto start = std::chrono::high_resolution_clock::now();
auto tokens = model.generate(audio_samples);
auto tokens = model.generate(buffer);
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration = end - start;

// Detokenize tokens
std::string result = model.detokenize(tokens);

// erase the last console line
std::cout << "\x1b[A";
// clear the line
std::cout << "\r\033[K";

std::cout << "Transcription: " << result << "\n";
std::cout << "Token generation took " << duration.count() << " seconds\n";
}
else
{
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
}
std::cout << "Transcription thread finished\n";
});

std::cout << "Recording... Press 'q' or 'ESC' to stop.\n";
Expand All @@ -137,13 +165,13 @@ int main(int argc, char* argv[])
}

// Stop audio capture
SDL_PauseAudio(1);
SDL_PauseAudioDevice(dev, 1);

// Wait for transcription thread to finish
transcription_thread.join();

// Clean up
SDL_CloseAudio();
SDL_CloseAudioDevice(dev);
SDL_Quit();
}
catch (const Ort::Exception& e)
Expand Down

0 comments on commit a75823c

Please sign in to comment.