Skip to content

Commit

Permalink
feat(cpp): move audio utils & save audio
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Dec 8, 2023
1 parent 54fea10 commit 6f95686
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 178 deletions.
1 change: 1 addition & 0 deletions android/src/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ set(
${RNWHISPER_LIB_DIR}/ggml-backend.c
${RNWHISPER_LIB_DIR}/ggml-quants.c
${RNWHISPER_LIB_DIR}/whisper.cpp
${RNWHISPER_LIB_DIR}/rn-audioutils.cpp
${RNWHISPER_LIB_DIR}/rn-whisper.cpp
${CMAKE_SOURCE_DIR}/jni.cpp
)
Expand Down
80 changes: 0 additions & 80 deletions android/src/main/java/com/rnwhisper/AudioUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@

import android.util.Log;

import java.util.ArrayList;
import java.lang.StringBuilder;
import java.io.IOException;
import java.io.FileReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
Expand All @@ -19,82 +15,6 @@
public class AudioUtils {
private static final String NAME = "RNWhisperAudioUtils";

private static final int SAMPLE_RATE = 16000;

private static byte[] shortToByte(short[] shortInts) {
int j = 0;
int length = shortInts.length;
byte[] byteData = new byte[length * 2];
for (int i = 0; i < length; i++) {
byteData[j++] = (byte) (shortInts[i] >>> 8);
byteData[j++] = (byte) (shortInts[i] >>> 0);
}
return byteData;
}

public static byte[] concatShortBuffers(ArrayList<short[]> buffers) {
int totalLength = 0;
for (int i = 0; i < buffers.size(); i++) {
totalLength += buffers.get(i).length;
}
byte[] result = new byte[totalLength * 2];
int offset = 0;
for (int i = 0; i < buffers.size(); i++) {
byte[] bytes = shortToByte(buffers.get(i));
System.arraycopy(bytes, 0, result, offset, bytes.length);
offset += bytes.length;
}

return result;
}

private static byte[] removeTrailingZeros(byte[] audioData) {
int i = audioData.length - 1;
while (i >= 0 && audioData[i] == 0) {
--i;
}
byte[] newData = new byte[i + 1];
System.arraycopy(audioData, 0, newData, 0, i + 1);
return newData;
}

public static void saveWavFile(byte[] rawData, String audioOutputFile) throws IOException {
Log.d(NAME, "call saveWavFile");
rawData = removeTrailingZeros(rawData);
DataOutputStream output = null;
try {
output = new DataOutputStream(new FileOutputStream(audioOutputFile));
// WAVE header
// see http://ccrma.stanford.edu/courses/422/projects/WaveFormat/
output.writeBytes("RIFF"); // chunk id
output.writeInt(Integer.reverseBytes(36 + rawData.length)); // chunk size
output.writeBytes("WAVE"); // format
output.writeBytes("fmt "); // subchunk 1 id
output.writeInt(Integer.reverseBytes(16)); // subchunk 1 size
output.writeShort(Short.reverseBytes((short) 1)); // audio format (1 = PCM)
output.writeShort(Short.reverseBytes((short) 1)); // number of channels
output.writeInt(Integer.reverseBytes(SAMPLE_RATE)); // sample rate
output.writeInt(Integer.reverseBytes(SAMPLE_RATE * 2)); // byte rate
output.writeShort(Short.reverseBytes((short) 2)); // block align
output.writeShort(Short.reverseBytes((short) 16)); // bits per sample
output.writeBytes("data"); // subchunk 2 id
output.writeInt(Integer.reverseBytes(rawData.length)); // subchunk 2 size
// Audio data (conversion big endian -> little endian)
short[] shorts = new short[rawData.length / 2];
ByteBuffer.wrap(rawData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(shorts);
ByteBuffer bytes = ByteBuffer.allocate(shorts.length * 2);
for (short s : shorts) {
bytes.putShort(s);
}
Log.d(NAME, "writing audio file: " + audioOutputFile);
output.write(bytes.array());
} finally {
if (output != null) {
output.close();
}
}
}

public static float[] decodeWaveFile(InputStream inputStream) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
Expand Down
17 changes: 2 additions & 15 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,8 @@ private boolean vad(ReadableMap options, int sliceIndex, int nSamples, int n) {
}

private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) {
String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null;
if (audioOutputPath != null) {
// TODO: Append in real time so we don't need to keep all slices & also reduce memory usage
Log.d(NAME, "Begin saving wav file to " + audioOutputPath);
// try {
// // TODO: cpp audio utils
// AudioUtils.saveWavFile(AudioUtils.concatShortBuffers(shortBufferSlices), audioOutputPath);
// } catch (IOException e) {
// Log.e(NAME, "Error saving wav file: " + e.getMessage());
// }
}
emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
removeRealtimeTranscribeJob(jobId, context);
finishRealtimeTranscribeJob(jobId, context, sliceNSamples.stream().mapToInt(i -> i).toArray());
}

public int startRealtimeTranscribe(int jobId, ReadableMap options) {
Expand Down Expand Up @@ -123,8 +112,6 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) {

createRealtimeTranscribeJob(jobId, context, options);

String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null;

sliceNSamples = new ArrayList<Integer>();
sliceNSamples.add(0);

Expand Down Expand Up @@ -508,7 +495,7 @@ protected static native void createRealtimeTranscribeJob(
long context,
ReadableMap options
);
protected static native void removeRealtimeTranscribeJob(int job_id, long context);
protected static native void finishRealtimeTranscribeJob(int job_id, long context, int[] sliceNSamples);
protected static native boolean vadSimple(int job_id, int slice_index, int n_samples, int n);
protected static native void putPcmData(short[] buffer, int slice_index, int n_samples, int n);
protected static native int fullWithJob(
Expand Down
27 changes: 24 additions & 3 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,24 +328,45 @@ Java_com_rnwhisper_WhisperContext_createRealtimeTranscribeJob(
vad.vad_ms = readablemap::getInt(env, options, "vadMs", 2000);
vad.vad_thold = readablemap::getFloat(env, options, "vadThold", 0.6f);
vad.freq_thold = readablemap::getFloat(env, options, "vadFreqThold", 100.0f);

jstring audio_output_path = readablemap::getString(env, options, "audioOutputPath", nullptr);
std::string *audio_output_path_str = nullptr;
if (audio_output_path != nullptr) {
audio_output_path_str = new std::string(env->GetStringUTFChars(audio_output_path, nullptr));
env->ReleaseStringUTFChars(audio_output_path, audio_output_path_str->c_str());
}
job->set_realtime_params(
vad,
readablemap::getInt(env, options, "realtimeAudioSec", 0),
readablemap::getInt(env, options, "realtimeAudioSliceSec", 0)
readablemap::getInt(env, options, "realtimeAudioSliceSec", 0),
audio_output_path_str
);
}

JNIEXPORT void JNICALL
Java_com_rnwhisper_WhisperContext_removeRealtimeTranscribeJob(
Java_com_rnwhisper_WhisperContext_finishRealtimeTranscribeJob(
JNIEnv *env,
jobject thiz,
jint job_id,
jlong context_ptr
jlong context_ptr,
jintArray slice_n_samples
) {
UNUSED(env);
UNUSED(thiz);
UNUSED(context_ptr);

rnwhisper::job *job = rnwhisper::job_get(job_id);
if (job->audio_output_path != nullptr) {
std::vector<int> slice_n_samples_vec;
jint *slice_n_samples_arr = env->GetIntArrayElements(slice_n_samples, nullptr);
slice_n_samples_vec = std::vector<int>(slice_n_samples_arr, slice_n_samples_arr + env->GetArrayLength(slice_n_samples));
env->ReleaseIntArrayElements(slice_n_samples, slice_n_samples_arr, JNI_ABORT);

rnaudioutils::save_wav_file(
rnaudioutils::concat_short_buffers(job->pcm_slices, slice_n_samples_vec),
*job->audio_output_path
);
}
job->free_pcm_slices();
rnwhisper::job_remove(job_id);
}
Expand Down
65 changes: 65 additions & 0 deletions cpp/rn-audioutils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include "rn-audioutils.h"

namespace rnaudioutils {

std::vector<uint8_t> concat_short_buffers(const std::vector<short*>& buffers, const std::vector<int>& slice_n_samples) {
std::vector<uint8_t> output_data;

for (size_t i = 0; i < buffers.size(); i++) {
int size = slice_n_samples[i]; // Number of shorts
short* slice = buffers[i];

// Copy each short as two bytes
for (int j = 0; j < size; j++) {
output_data.push_back(static_cast<uint8_t>(slice[j] & 0xFF)); // Lower byte
output_data.push_back(static_cast<uint8_t>((slice[j] >> 8) & 0xFF)); // Higher byte
}
}

return output_data;
}

std::vector<uint8_t> remove_trailing_zeros(const std::vector<uint8_t>& audio_data) {
auto last = std::find_if(audio_data.rbegin(), audio_data.rend(), [](uint8_t byte) { return byte != 0; });
return std::vector<uint8_t>(audio_data.begin(), last.base());
}

void save_wav_file(const std::vector<uint8_t>& raw, const std::string& file) {
std::vector<uint8_t> data = remove_trailing_zeros(raw);

std::ofstream output(file, std::ios::binary);

if (!output.is_open()) {
std::cerr << "Failed to open file for writing: " << file << std::endl;
return;
}

// WAVE header
output.write("RIFF", 4);
int32_t chunk_size = 36 + static_cast<int32_t>(data.size());
output.write(reinterpret_cast<char*>(&chunk_size), sizeof(chunk_size));
output.write("WAVE", 4);
output.write("fmt ", 4);
int32_t sub_chunk_size = 16;
output.write(reinterpret_cast<char*>(&sub_chunk_size), sizeof(sub_chunk_size));
short audio_format = 1;
output.write(reinterpret_cast<char*>(&audio_format), sizeof(audio_format));
short num_channels = 1;
output.write(reinterpret_cast<char*>(&num_channels), sizeof(num_channels));
int32_t sample_rate = WHISPER_SAMPLE_RATE;
output.write(reinterpret_cast<char*>(&sample_rate), sizeof(sample_rate));
int32_t byte_rate = WHISPER_SAMPLE_RATE * 2;
output.write(reinterpret_cast<char*>(&byte_rate), sizeof(byte_rate));
short block_align = 2;
output.write(reinterpret_cast<char*>(&block_align), sizeof(block_align));
short bits_per_sample = 16;
output.write(reinterpret_cast<char*>(&bits_per_sample), sizeof(bits_per_sample));
output.write("data", 4);
int32_t sub_chunk2_size = static_cast<int32_t>(data.size());
output.write(reinterpret_cast<char*>(&sub_chunk2_size), sizeof(sub_chunk2_size));
output.write(reinterpret_cast<const char*>(data.data()), data.size());

output.close();
}

} // namespace rnaudioutils
14 changes: 14 additions & 0 deletions cpp/rn-audioutils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include <iostream>
#include <fstream>
#include <vector>
#include <cstdint>
#include <cstring>
#include <algorithm>
#include "whisper.h"

namespace rnaudioutils {

std::vector<uint8_t> concat_short_buffers(const std::vector<short*>& buffers, const std::vector<int>& slice_n_samples);
void save_wav_file(const std::vector<uint8_t>& raw, const std::string& file);

} // namespace rnaudioutils
8 changes: 7 additions & 1 deletion cpp/rn-whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,17 @@ job::~job() {
fprintf(stderr, "%s: job_id: %d\n", __func__, job_id);
}

void job::set_realtime_params(vad_params params, int sec, int slice_sec) {
void job::set_realtime_params(
vad_params params,
int sec,
int slice_sec,
std::string* output_path
) {
vad = params;
if (vad.vad_ms < 2000) vad.vad_ms = 2000;
audio_sec = sec > 0 ? sec : DEFAULT_MAX_AUDIO_SEC;
audio_slice_sec = slice_sec > 0 && slice_sec < audio_sec ? slice_sec : audio_sec;
audio_output_path = output_path;
}

bool job::vad_simple(int slice_index, int n_samples, int n) {
Expand Down
4 changes: 3 additions & 1 deletion cpp/rn-whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <string>
#include <vector>
#include "whisper.h"
#include "rn-audioutils.h"

namespace rnwhisper {

Expand All @@ -29,8 +30,9 @@ struct job {
vad_params vad;
int audio_sec = 0;
int audio_slice_sec = 0;
std::string* audio_output_path = nullptr;
std::vector<short *> pcm_slices;
void set_realtime_params(vad_params vad, int audio_sec, int audio_slice_sec);
void set_realtime_params(vad_params vad, int audio_sec, int audio_slice_sec, std::string* audio_output_path);
bool vad_simple(int slice_index, int n_samples, int n);
void put_pcm_data(short* pcm, int slice_index, int n_samples, int n);
float* pcm_slice_to_f32(int slice_index, int size);
Expand Down
2 changes: 0 additions & 2 deletions ios/RNWhisperAudioUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

@interface RNWhisperAudioUtils : NSObject

+ (NSData *)concatShortBuffers:(NSMutableArray<NSValue *> *)buffers sliceNSamples:(NSMutableArray<NSNumber *> *)sliceNSamples;
+ (void)saveWavFile:(NSData *)rawData audioOutputFile:(NSString *)audioOutputFile;
+ (float *)decodeWaveFile:(NSString*)filePath count:(int *)count;

@end
56 changes: 0 additions & 56 deletions ios/RNWhisperAudioUtils.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,6 @@

@implementation RNWhisperAudioUtils

+ (NSData *)concatShortBuffers:(NSMutableArray<NSValue *> *)buffers sliceNSamples:(NSMutableArray<NSNumber *> *)sliceNSamples {
NSMutableData *outputData = [NSMutableData data];
for (int i = 0; i < buffers.count; i++) {
int size = [sliceNSamples objectAtIndex:i].intValue;
NSValue *buffer = [buffers objectAtIndex:i];
short *bufferPtr = buffer.pointerValue;
[outputData appendBytes:bufferPtr length:size * sizeof(short)];
}
return outputData;
}

+ (void)saveWavFile:(NSData *)rawData audioOutputFile:(NSString *)audioOutputFile {
NSMutableData *outputData = [NSMutableData data];

// WAVE header
[outputData appendData:[@"RIFF" dataUsingEncoding:NSUTF8StringEncoding]]; // chunk id
int chunkSize = CFSwapInt32HostToLittle(36 + rawData.length);
[outputData appendBytes:&chunkSize length:sizeof(chunkSize)];
[outputData appendData:[@"WAVE" dataUsingEncoding:NSUTF8StringEncoding]]; // format
[outputData appendData:[@"fmt " dataUsingEncoding:NSUTF8StringEncoding]]; // subchunk 1 id

int subchunk1Size = CFSwapInt32HostToLittle(16);
[outputData appendBytes:&subchunk1Size length:sizeof(subchunk1Size)];

short audioFormat = CFSwapInt16HostToLittle(1); // PCM
[outputData appendBytes:&audioFormat length:sizeof(audioFormat)];

short numChannels = CFSwapInt16HostToLittle(1); // mono
[outputData appendBytes:&numChannels length:sizeof(numChannels)];

int sampleRate = CFSwapInt32HostToLittle(WHISPER_SAMPLE_RATE);
[outputData appendBytes:&sampleRate length:sizeof(sampleRate)];

// (bitDepth * sampleRate * channels) >> 3
int byteRate = CFSwapInt32HostToLittle(WHISPER_SAMPLE_RATE * 1 * 16 / 8);
[outputData appendBytes:&byteRate length:sizeof(byteRate)];

// (bitDepth * channels) >> 3
short blockAlign = CFSwapInt16HostToLittle(16 / 8);
[outputData appendBytes:&blockAlign length:sizeof(blockAlign)];

// bitDepth
short bitsPerSample = CFSwapInt16HostToLittle(16);
[outputData appendBytes:&bitsPerSample length:sizeof(bitsPerSample)];

[outputData appendData:[@"data" dataUsingEncoding:NSUTF8StringEncoding]]; // subchunk 2 id
int subchunk2Size = CFSwapInt32HostToLittle((int)rawData.length);
[outputData appendBytes:&subchunk2Size length:sizeof(subchunk2Size)];

// Audio data
[outputData appendData:rawData];

// Save to file
[outputData writeToFile:audioOutputFile atomically:YES];
}

+ (float *)decodeWaveFile:(NSString*)filePath count:(int *)count {
NSURL *url = [NSURL fileURLWithPath:filePath];
NSData *fileData = [NSData dataWithContentsOfURL:url];
Expand Down
2 changes: 1 addition & 1 deletion ios/RNWhisperContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ typedef struct {
bool isCapturing;
bool isStoppedByAction;
int nSamplesTranscribing;
NSMutableArray<NSNumber *> *sliceNSamples;
std::vector<int> sliceNSamples;
bool isUseSlices;
int sliceIndex;
int transcribeSliceIndex;
Expand Down
Loading

0 comments on commit 6f95686

Please sign in to comment.