diff --git a/cpp/rn-audioutils.cpp b/cpp/rn-audioutils.cpp index 292a704..bde65fd 100644 --- a/cpp/rn-audioutils.cpp +++ b/cpp/rn-audioutils.cpp @@ -3,41 +3,24 @@ namespace rnaudioutils { -std::vector concat_short_buffers(const std::vector& buffers, const std::vector& slice_n_samples) { - std::vector output_data; - - for (size_t i = 0; i < buffers.size(); i++) { - int size = slice_n_samples[i]; // Number of shorts - short* slice = buffers[i]; - - // Copy each short as two bytes - for (int j = 0; j < size; j++) { - output_data.push_back(static_cast(slice[j] & 0xFF)); // Lower byte - output_data.push_back(static_cast((slice[j] >> 8) & 0xFF)); // Higher byte - } - } - - return output_data; -} - -std::vector remove_trailing_zeros(const std::vector& audio_data) { - auto last = std::find_if(audio_data.rbegin(), audio_data.rend(), [](uint8_t byte) { return byte != 0; }); - return std::vector(audio_data.begin(), last.base()); +void append_wav_data(const short* data, const int n_samples, const std::string& file) { + std::ofstream output(file, std::ios::binary | std::ios::app); + output.write(reinterpret_cast(data), n_samples * sizeof(short)); + output.close(); } -void save_wav_file(const std::vector& raw, const std::string& file) { - std::vector data = remove_trailing_zeros(raw); - - std::ofstream output(file, std::ios::binary); +void add_wav_header_to_file(const std::string& file, const int data_size) { + std::ofstream output(file, std::ios::binary | std::ios::app); if (!output.is_open()) { RNWHISPER_LOG_ERROR("Failed to open file for writing: %s\n", file.c_str()); return; } - // WAVE header + output.seekp(0, std::ios::beg); + output.write("RIFF", 4); - int32_t chunk_size = 36 + static_cast(data.size()); + int32_t chunk_size = 36 + static_cast(data_size); output.write(reinterpret_cast(&chunk_size), sizeof(chunk_size)); output.write("WAVE", 4); output.write("fmt ", 4); @@ -56,13 +39,10 @@ void save_wav_file(const std::vector& raw, const std::string& file) { short bits_per_sample = 16; output.write(reinterpret_cast(&bits_per_sample), sizeof(bits_per_sample)); output.write("data", 4); - int32_t sub_chunk2_size = static_cast(data.size()); + int32_t sub_chunk2_size = static_cast(data_size); output.write(reinterpret_cast(&sub_chunk2_size), sizeof(sub_chunk2_size)); - output.write(reinterpret_cast(data.data()), data.size()); output.close(); - - RNWHISPER_LOG_INFO("Saved audio file: %s\n", file.c_str()); } } // namespace rnaudioutils diff --git a/cpp/rn-audioutils.h b/cpp/rn-audioutils.h index 9e49976..7174854 100644 --- a/cpp/rn-audioutils.h +++ b/cpp/rn-audioutils.h @@ -8,7 +8,7 @@ namespace rnaudioutils { -std::vector concat_short_buffers(const std::vector& buffers, const std::vector& slice_n_samples); -void save_wav_file(const std::vector& raw, const std::string& file); +void append_wav_data(const short* data, const int n_samples, const std::string& file); +void add_wav_header_to_file(const std::string& file, const int data_size); } // namespace rnaudioutils diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp index ab98be6..62c0e27 100644 --- a/cpp/rn-whisper.cpp +++ b/cpp/rn-whisper.cpp @@ -100,6 +100,7 @@ void job::put_pcm_data(short* data, int slice_index, int n_samples, int n) { for (int i = 0; i < n; i++) { pcm[i + n_samples] = data[i]; } + pcm_data_size += n; } float* job::pcm_slice_to_f32(int slice_index, int size) { diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h index 46adbb9..ea4b2b7 100644 --- a/cpp/rn-whisper.h +++ b/cpp/rn-whisper.h @@ -33,6 +33,7 @@ struct job { int audio_slice_sec = 0; float audio_min_sec = 0; const char* audio_output_path = nullptr; + int pcm_data_size = 0; std::vector pcm_slices; void set_realtime_params(vad_params vad, int sec, int slice_sec, float min_sec, const char* output_path); bool vad_simple(int slice_index, int n_samples, int n); diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index 5b12c7e..dce6925 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -232,10 +232,9 @@ void AudioInputCallback(void * inUserData, - (void)finishRealtimeTranscribe:(RNWhisperContextRecordState*) state result:(NSDictionary*)result { // Save wav if needed if (state->job->audio_output_path != nullptr) { - // TODO: Append in real time so we don't need to keep all slices & also reduce memory usage - rnaudioutils::save_wav_file( - rnaudioutils::concat_short_buffers(state->job->pcm_slices, state->sliceNSamples), - state->job->audio_output_path + rnaudioutils::add_wav_header_to_file( + state->job->audio_output_path, + state->job->pcm_data_size ); } state->transcribeHandler(state->job->job_id, @"end", result); @@ -284,6 +283,14 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state { state->nSamplesTranscribing == nSamplesOfIndex && state->transcribeSliceIndex != state->sliceIndex ) { + if (state->job->audio_output_path != nullptr) { + rnaudioutils::append_wav_data( + state->job->pcm_slices[state->transcribeSliceIndex], + state->sliceNSamples[state->transcribeSliceIndex], + state->job->audio_output_path + ); + } + // TODO: Clean up the previous slice state->transcribeSliceIndex++; state->nSamplesTranscribing = 0; }