Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cpp): unify some platform code (audio slices, utils, ...) #166

Merged
merged 20 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions android/src/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ set(
${RNWHISPER_LIB_DIR}/ggml-backend.c
${RNWHISPER_LIB_DIR}/ggml-quants.c
${RNWHISPER_LIB_DIR}/whisper.cpp
${RNWHISPER_LIB_DIR}/rn-audioutils.cpp
${RNWHISPER_LIB_DIR}/rn-whisper.cpp
${CMAKE_SOURCE_DIR}/jni.cpp
)
Expand All @@ -33,6 +34,10 @@ function(build_library target_name)
target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4)
endif ()

if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
target_compile_options(${target_name} PRIVATE -DRNWHISPER_ANDROID_ENABLE_LOGGING)
endif ()

# NOTE: If you want to debug the native code, you can uncomment if and endif
# if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")

Expand Down
80 changes: 0 additions & 80 deletions android/src/main/java/com/rnwhisper/AudioUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@

import android.util.Log;

import java.util.ArrayList;
import java.lang.StringBuilder;
import java.io.IOException;
import java.io.FileReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
Expand All @@ -19,82 +15,6 @@
public class AudioUtils {
private static final String NAME = "RNWhisperAudioUtils";

private static final int SAMPLE_RATE = 16000;

private static byte[] shortToByte(short[] shortInts) {
int j = 0;
int length = shortInts.length;
byte[] byteData = new byte[length * 2];
for (int i = 0; i < length; i++) {
byteData[j++] = (byte) (shortInts[i] >>> 8);
byteData[j++] = (byte) (shortInts[i] >>> 0);
}
return byteData;
}

public static byte[] concatShortBuffers(ArrayList<short[]> buffers) {
int totalLength = 0;
for (int i = 0; i < buffers.size(); i++) {
totalLength += buffers.get(i).length;
}
byte[] result = new byte[totalLength * 2];
int offset = 0;
for (int i = 0; i < buffers.size(); i++) {
byte[] bytes = shortToByte(buffers.get(i));
System.arraycopy(bytes, 0, result, offset, bytes.length);
offset += bytes.length;
}

return result;
}

private static byte[] removeTrailingZeros(byte[] audioData) {
int i = audioData.length - 1;
while (i >= 0 && audioData[i] == 0) {
--i;
}
byte[] newData = new byte[i + 1];
System.arraycopy(audioData, 0, newData, 0, i + 1);
return newData;
}

public static void saveWavFile(byte[] rawData, String audioOutputFile) throws IOException {
Log.d(NAME, "call saveWavFile");
rawData = removeTrailingZeros(rawData);
DataOutputStream output = null;
try {
output = new DataOutputStream(new FileOutputStream(audioOutputFile));
// WAVE header
// see http://ccrma.stanford.edu/courses/422/projects/WaveFormat/
output.writeBytes("RIFF"); // chunk id
output.writeInt(Integer.reverseBytes(36 + rawData.length)); // chunk size
output.writeBytes("WAVE"); // format
output.writeBytes("fmt "); // subchunk 1 id
output.writeInt(Integer.reverseBytes(16)); // subchunk 1 size
output.writeShort(Short.reverseBytes((short) 1)); // audio format (1 = PCM)
output.writeShort(Short.reverseBytes((short) 1)); // number of channels
output.writeInt(Integer.reverseBytes(SAMPLE_RATE)); // sample rate
output.writeInt(Integer.reverseBytes(SAMPLE_RATE * 2)); // byte rate
output.writeShort(Short.reverseBytes((short) 2)); // block align
output.writeShort(Short.reverseBytes((short) 16)); // bits per sample
output.writeBytes("data"); // subchunk 2 id
output.writeInt(Integer.reverseBytes(rawData.length)); // subchunk 2 size
// Audio data (conversion big endian -> little endian)
short[] shorts = new short[rawData.length / 2];
ByteBuffer.wrap(rawData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(shorts);
ByteBuffer bytes = ByteBuffer.allocate(shorts.length * 2);
for (short s : shorts) {
bytes.putShort(s);
}
Log.d(NAME, "writing audio file: " + audioOutputFile);
output.write(bytes.array());
} finally {
if (output != null) {
output.close();
}
}
}

public static float[] decodeWaveFile(InputStream inputStream) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
Expand Down
134 changes: 48 additions & 86 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ public class WhisperContext {
private AudioRecord recorder = null;
private int bufferSize;
private int nSamplesTranscribing = 0;
private ArrayList<short[]> shortBufferSlices;
// Remember number of samples in each slice
private ArrayList<Integer> sliceNSamples;
// Current buffer slice index
Expand All @@ -66,7 +65,6 @@ public WhisperContext(int id, ReactApplicationContext reactContext, long context
}

private void rewind() {
shortBufferSlices = null;
sliceNSamples = null;
sliceIndex = 0;
transcribeSliceIndex = 0;
Expand All @@ -79,41 +77,14 @@ private void rewind() {
fullHandler = null;
}

private boolean vad(ReadableMap options, short[] shortBuffer, int nSamples, int n) {
boolean isSpeech = true;
if (!isTranscribing && options.hasKey("useVad") && options.getBoolean("useVad")) {
int vadMs = options.hasKey("vadMs") ? options.getInt("vadMs") : 2000;
if (vadMs < 2000) vadMs = 2000;
int sampleSize = (int) (SAMPLE_RATE * vadMs / 1000);
if (nSamples + n > sampleSize) {
int start = nSamples + n - sampleSize;
float[] audioData = new float[sampleSize];
for (int i = 0; i < sampleSize; i++) {
audioData[i] = shortBuffer[i + start] / 32768.0f;
}
float vadThold = options.hasKey("vadThold") ? (float) options.getDouble("vadThold") : 0.6f;
float vadFreqThold = options.hasKey("vadFreqThold") ? (float) options.getDouble("vadFreqThold") : 0.6f;
isSpeech = vadSimple(audioData, sampleSize, vadThold, vadFreqThold);
} else {
isSpeech = false;
}
}
return isSpeech;
private boolean vad(int sliceIndex, int nSamples, int n) {
if (isTranscribing) return true;
return vadSimple(jobId, sliceIndex, nSamples, n);
}

private void finishRealtimeTranscribe(ReadableMap options, WritableMap result) {
String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null;
if (audioOutputPath != null) {
// TODO: Append in real time so we don't need to keep all slices & also reduce memory usage
Log.d(NAME, "Begin saving wav file to " + audioOutputPath);
try {
AudioUtils.saveWavFile(AudioUtils.concatShortBuffers(shortBufferSlices), audioOutputPath);
} catch (IOException e) {
Log.e(NAME, "Error saving wav file: " + e.getMessage());
}
}

private void finishRealtimeTranscribe(WritableMap result) {
emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
finishRealtimeTranscribeJob(jobId, context, sliceNSamples.stream().mapToInt(i -> i).toArray());
}

public int startRealtimeTranscribe(int jobId, ReadableMap options) {
Expand All @@ -135,16 +106,12 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) {

int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0;
final int audioSec = realtimeAudioSec > 0 ? realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC;

int realtimeAudioSliceSec = options.hasKey("realtimeAudioSliceSec") ? options.getInt("realtimeAudioSliceSec") : 0;
final int audioSliceSec = realtimeAudioSliceSec > 0 && realtimeAudioSliceSec < audioSec ? realtimeAudioSliceSec : audioSec;

isUseSlices = audioSliceSec < audioSec;

String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null;
createRealtimeTranscribeJob(jobId, context, options);

shortBufferSlices = new ArrayList<short[]>();
shortBufferSlices.add(new short[audioSliceSec * SAMPLE_RATE]);
sliceNSamples = new ArrayList<Integer>();
sliceNSamples.add(0);

Expand Down Expand Up @@ -175,37 +142,29 @@ public void run() {
nSamples == nSamplesTranscribing &&
sliceIndex == transcribeSliceIndex
) {
finishRealtimeTranscribe(options, Arguments.createMap());
finishRealtimeTranscribe(Arguments.createMap());
} else if (!isTranscribing) {
short[] shortBuffer = shortBufferSlices.get(sliceIndex);
boolean isSpeech = vad(options, shortBuffer, nSamples, 0);
if (!isSpeech) {
finishRealtimeTranscribe(options, Arguments.createMap());
if (!vad(sliceIndex, nSamples, 0)) {
finishRealtimeTranscribe(Arguments.createMap());
break;
}
isTranscribing = true;
fullTranscribeSamples(options, true);
fullTranscribeSamples(true);
}
break;
}

// Append to buffer
short[] shortBuffer = shortBufferSlices.get(sliceIndex);
if (nSamples + n > audioSliceSec * SAMPLE_RATE) {
Log.d(NAME, "next slice");

sliceIndex++;
nSamples = 0;
shortBuffer = new short[audioSliceSec * SAMPLE_RATE];
shortBufferSlices.add(shortBuffer);
sliceNSamples.add(0);
}
putPcmData(jobId, buffer, sliceIndex, nSamples, n);

for (int i = 0; i < n; i++) {
shortBuffer[nSamples + i] = buffer[i];
}

boolean isSpeech = vad(options, shortBuffer, nSamples, n);
boolean isSpeech = vad(sliceIndex, nSamples, n);

nSamples += n;
sliceNSamples.set(sliceIndex, nSamples);
Expand All @@ -217,7 +176,7 @@ public void run() {
fullHandler = new Thread(new Runnable() {
@Override
public void run() {
fullTranscribeSamples(options, false);
fullTranscribeSamples(false);
}
});
fullHandler.start();
Expand All @@ -228,7 +187,7 @@ public void run() {
}

if (!isTranscribing) {
finishRealtimeTranscribe(options, Arguments.createMap());
finishRealtimeTranscribe(Arguments.createMap());
}
if (fullHandler != null) {
fullHandler.join(); // Wait for full transcribe to finish
Expand All @@ -246,26 +205,16 @@ public void run() {
return state;
}

private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingCheck) {
private void fullTranscribeSamples(boolean skipCapturingCheck) {
int nSamplesOfIndex = sliceNSamples.get(transcribeSliceIndex);

if (!isCapturing && !skipCapturingCheck) return;

short[] shortBuffer = shortBufferSlices.get(transcribeSliceIndex);
int nSamples = sliceNSamples.get(transcribeSliceIndex);

nSamplesTranscribing = nSamplesOfIndex;

// convert I16 to F32
float[] nSamplesBuffer32 = new float[nSamplesTranscribing];
for (int i = 0; i < nSamplesTranscribing; i++) {
nSamplesBuffer32[i] = shortBuffer[i] / 32768.0f;
}

Log.d(NAME, "Start transcribing realtime: " + nSamplesTranscribing);

int timeStart = (int) System.currentTimeMillis();
int code = full(jobId, options, nSamplesBuffer32, nSamplesTranscribing);
int code = fullWithJob(jobId, context, transcribeSliceIndex, nSamplesTranscribing);
int timeEnd = (int) System.currentTimeMillis();
int timeRecording = (int) (nSamplesTranscribing / SAMPLE_RATE * 1000);

Expand Down Expand Up @@ -302,7 +251,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe
if (isStopped && !continueNeeded) {
payload.putBoolean("isCapturing", false);
payload.putBoolean("isStoppedByAction", isStoppedByAction);
finishRealtimeTranscribe(options, payload);
finishRealtimeTranscribe(payload);
} else if (code == 0) {
payload.putBoolean("isCapturing", true);
emitTranscribeEvent("@RNWhisper_onRealtimeTranscribe", payload);
Expand All @@ -313,7 +262,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe

if (continueNeeded) {
// If no more capturing, continue transcribing until all slices are transcribed
fullTranscribeSamples(options, true);
fullTranscribeSamples(true);
} else if (isStopped) {
// No next, cleanup
rewind();
Expand Down Expand Up @@ -383,32 +332,30 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea
this.jobId = jobId;
isTranscribing = true;
float[] audioData = AudioUtils.decodeWaveFile(inputStream);
int code = full(jobId, options, audioData, audioData.length);
isTranscribing = false;
this.jobId = -1;
if (code != 0 && code != 999) {
throw new Exception("Failed to transcribe the file. Code: " + code);
}
WritableMap result = getTextSegments(0, getTextSegmentCount(context));
result.putBoolean("isAborted", isStoppedByAction);
return result;
}

private int full(int jobId, ReadableMap options, float[] audioData, int audioDataLen) {
boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress");
boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments");
return fullTranscribe(
int code = fullWithNewJob(
jobId,
context,
// float[] audio_data,
audioData,
// jint audio_data_len,
audioDataLen,
audioData.length,
// ReadableMap options,
options,
// Callback callback
hasProgressCallback || hasNewSegmentsCallback ? new Callback(this, hasProgressCallback, hasNewSegmentsCallback) : null
);

isTranscribing = false;
this.jobId = -1;
if (code != 0 && code != 999) {
throw new Exception("Failed to transcribe the file. Code: " + code);
}
WritableMap result = getTextSegments(0, getTextSegmentCount(context));
result.putBoolean("isAborted", isStoppedByAction);
return result;
}

private WritableMap getTextSegments(int start, int count) {
Expand Down Expand Up @@ -527,12 +474,13 @@ private static String cpuInfo() {
}
}


// JNI methods
protected static native long initContext(String modelPath);
protected static native long initContextWithAsset(AssetManager assetManager, String modelPath);
protected static native long initContextWithInputStream(PushbackInputStream inputStream);
protected static native boolean vadSimple(float[] audio_data, int audio_data_len, float vad_thold, float vad_freq_thold);
protected static native int fullTranscribe(
protected static native void freeContext(long contextPtr);

protected static native int fullWithNewJob(
int job_id,
long context,
float[] audio_data,
Expand All @@ -546,5 +494,19 @@ protected static native int fullTranscribe(
protected static native String getTextSegment(long context, int index);
protected static native int getTextSegmentT0(long context, int index);
protected static native int getTextSegmentT1(long context, int index);
protected static native void freeContext(long contextPtr);

protected static native void createRealtimeTranscribeJob(
int job_id,
long context,
ReadableMap options
);
protected static native void finishRealtimeTranscribeJob(int job_id, long context, int[] sliceNSamples);
protected static native boolean vadSimple(int job_id, int slice_index, int n_samples, int n);
protected static native void putPcmData(int job_id, short[] buffer, int slice_index, int n_samples, int n);
protected static native int fullWithJob(
int job_id,
long context,
int slice_index,
int n_samples
);
}
Loading
Loading