From 61f01e7e3576ea114490713989eaf4a73a6c07c1 Mon Sep 17 00:00:00 2001
From: Jess Edmund Fan
Date: Fri, 22 Sep 2023 17:35:44 -0700
Subject: [PATCH] feat(android): allow saving recorded audio as wav on startRealtimeTranscribe (#128)

* feat: android startRealtimeTranscribe allow saving recorded audio as .wav

This change allows use cases such as playing back audio while highlighting the current audio segment

* style: remove unneeded changes to spacing

bad prettier format

* fix: remove forced .wav extension on written audio files

Fix to comment

* fix: update audioOutputPath comment

---------

Co-authored-by: Jhen-Jie Hong
---
 .../java/com/rnwhisper/WhisperContext.java    | 122 ++++++++++++++++++
 src/index.ts                                  |   6 +
 2 files changed, 128 insertions(+)

diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java
index 3b73e72..1c9964e 100644
--- a/android/src/main/java/com/rnwhisper/WhisperContext.java
+++ b/android/src/main/java/com/rnwhisper/WhisperContext.java
@@ -24,6 +24,8 @@
 import java.io.ByteArrayOutputStream;
 import java.io.File;
 import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.DataOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.PushbackInputStream;
@@ -86,6 +88,115 @@ private void rewind() {
     fullHandler = null;
   }
 
+  /**
+   * Converts 16-bit PCM samples to bytes, high byte first (big endian).
+   *
+   * @param shortInts samples to convert
+   * @return array of {@code shortInts.length * 2} bytes
+   */
+  public byte[] shortToByte(short[] shortInts) {
+    int j = 0;
+    int length = shortInts.length;
+    byte[] byteData = new byte[length * 2];
+    for (int i = 0; i < length; i++) {
+      byteData[j++] = (byte) (shortInts[i] >>> 8);
+      byteData[j++] = (byte) (shortInts[i] >>> 0);
+    }
+    return byteData;
+  }
+
+  /**
+   * Concatenates the sample buffers, in order, into one big-endian byte array.
+   *
+   * @param buffers sample buffers to join
+   * @return the bytes of every buffer, two bytes per sample
+   */
+  public byte[] concatShortBuffers(ArrayList<short[]> buffers) {
+    int totalLength = 0;
+    for (int i = 0; i < buffers.size(); i++) {
+      totalLength += buffers.get(i).length;
+    }
+    byte[] result = new byte[totalLength * 2];
+    int offset = 0;
+    for (int i = 0; i < buffers.size(); i++) {
+      byte[] bytes = shortToByte(buffers.get(i));
+      System.arraycopy(bytes, 0, result, offset, bytes.length);
+      offset += bytes.length;
+    }
+
+    return result;
+  }
+
+  /**
+   * Drops trailing silence (all-zero samples) from 16-bit PCM data.
+   *
+   * Trimming stops on a sample boundary: a final sample whose low byte is zero
+   * (e.g. 0x0100) keeps both of its bytes, so the result never ends mid-sample.
+   *
+   * @param audioData PCM bytes, two bytes per sample
+   * @return a copy of {@code audioData} without the trailing zero samples
+   */
+  public byte[] removeTrailingZeros(byte[] audioData) {
+    int i = audioData.length - 1;
+    while (i >= 0 && audioData[i] == 0) {
+      --i;
+    }
+    int newLength = i + 1;
+    if (newLength % 2 != 0 && newLength < audioData.length) {
+      // Last non-zero byte is the first byte of a sample; keep its second byte too
+      newLength++;
+    }
+    byte[] newData = new byte[newLength];
+    System.arraycopy(audioData, 0, newData, 0, newLength);
+    return newData;
+  }
+
+  /**
+   * Writes PCM data to a 16-bit mono WAV file sampled at SAMPLE_RATE.
+   *
+   * @param rawData big-endian 16-bit samples, as produced by concatShortBuffers
+   * @param audioOutputFile path of the file to write
+   * @throws IOException if the file cannot be opened or written
+   */
+  private void saveWavFile(byte[] rawData, String audioOutputFile) throws IOException {
+    Log.d(NAME, "call saveWavFile");
+    rawData = removeTrailingZeros(rawData);
+    // Only whole samples are written, so the sizes in the header must count whole samples too
+    int dataLength = (rawData.length / 2) * 2;
+    DataOutputStream output = null;
+    try {
+      output = new DataOutputStream(new FileOutputStream(audioOutputFile));
+      // WAVE header
+      // see http://ccrma.stanford.edu/courses/422/projects/WaveFormat/
+      output.writeBytes("RIFF"); // chunk id
+      output.writeInt(Integer.reverseBytes(36 + dataLength)); // chunk size
+      output.writeBytes("WAVE"); // format
+      output.writeBytes("fmt "); // subchunk 1 id
+      output.writeInt(Integer.reverseBytes(16)); // subchunk 1 size
+      output.writeShort(Short.reverseBytes((short) 1)); // audio format (1 = PCM)
+      output.writeShort(Short.reverseBytes((short) 1)); // number of channels
+      output.writeInt(Integer.reverseBytes(SAMPLE_RATE)); // sample rate
+      output.writeInt(Integer.reverseBytes(SAMPLE_RATE * 2)); // byte rate
+      output.writeShort(Short.reverseBytes((short) 2)); // block align
+      output.writeShort(Short.reverseBytes((short) 16)); // bits per sample
+      output.writeBytes("data"); // subchunk 2 id
+      output.writeInt(Integer.reverseBytes(dataLength)); // subchunk 2 size
+      // Audio data (conversion big endian -> little endian)
+      short[] shorts = new short[rawData.length / 2];
+      ByteBuffer.wrap(rawData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(shorts);
+      ByteBuffer bytes = ByteBuffer.allocate(shorts.length * 2);
+      for (short s : shorts) {
+        bytes.putShort(s);
+      }
+      Log.d(NAME, "writing audio file: " + audioOutputFile);
+      output.write(bytes.array());
+    } finally {
+      if (output != null) {
+        output.close();
+      }
+    }
+  }
+
   public int startRealtimeTranscribe(int jobId, ReadableMap options) {
     if (isCapturing || isTranscribing) {
       return -100;
@@ -111,6 +222,8 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) {
 
     isUseSlices = audioSliceSec < audioSec;
 
+    final String audioOutputPath = options.hasKey("audioOutputPath") ? options.getString("audioOutputPath") : null;
+
     shortBufferSlices = new ArrayList();
     shortBufferSlices.add(new short[audioSliceSec * SAMPLE_RATE]);
     sliceNSamples = new ArrayList();
@@ -183,6 +296,15 @@ public void run() {
               Log.e(NAME, "Error transcribing realtime: " + e.getMessage());
             }
           }
+          if (audioOutputPath != null) {
+            Log.d(NAME, "Begin saving wav file to " + audioOutputPath);
+            try {
+              saveWavFile(concatShortBuffers(shortBufferSlices), audioOutputPath);
+            } catch (IOException e) {
+              // A failed save must not prevent the end event below from being emitted
+              Log.e(NAME, "Error saving wav file: " + e.getMessage());
+            }
+          }
           if (!isTranscribing) {
             emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
           }
diff --git a/src/index.ts b/src/index.ts
index 2c291d9..bb126b4 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -58,6 +58,12 @@ export type TranscribeRealtimeOptions = TranscribeOptions & {
    * (Default: Equal to `realtimeMaxAudioSec`)
    */
   realtimeAudioSliceSec?: number
+  /**
+   * Output path for audio file. If not set, the audio file will not be saved
+   * TODO: Support iOS
+   * (Default: Undefined)
+   */
+  audioOutputPath?: string
 }
 
 export type TranscribeRealtimeEvent = {