Skip to content

Commit

Permalink
feat(android): support transcribeData & transcribeFile with base64
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Nov 9, 2024
1 parent bf8ba4e commit 13137bf
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 48 deletions.
39 changes: 27 additions & 12 deletions android/src/main/java/com/rnwhisper/AudioUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,29 @@

import android.util.Log;

import java.io.IOException;
import java.io.FileReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.ShortBuffer;
import java.util.Base64;

import java.util.Arrays;

public class AudioUtils {
private static final String NAME = "RNWhisperAudioUtils";

public static float[] decodeWaveFile(InputStream inputStream) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
baos.write(buffer, 0, bytesRead);
}
ByteBuffer byteBuffer = ByteBuffer.wrap(baos.toByteArray());
private static float[] bufferToFloatArray(byte[] buffer, Boolean cutHeader) {
ByteBuffer byteBuffer = ByteBuffer.wrap(buffer);
byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
byteBuffer.position(44);
ShortBuffer shortBuffer = byteBuffer.asShortBuffer();
short[] shortArray = new short[shortBuffer.limit()];
shortBuffer.get(shortArray);
if (cutHeader) {
shortArray = Arrays.copyOfRange(shortArray, 44, shortArray.length);
}
float[] floatArray = new float[shortArray.length];
for (int i = 0; i < shortArray.length; i++) {
floatArray[i] = ((float) shortArray[i]) / 32767.0f;
Expand All @@ -36,4 +33,22 @@ public static float[] decodeWaveFile(InputStream inputStream) throws IOException
}
return floatArray;
}
}

public static float[] decodeWaveFile(InputStream inputStream) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
baos.write(buffer, 0, bytesRead);
}
return bufferToFloatArray(baos.toByteArray(), true);
}

public static float[] decodeWaveData(String dataBase64) throws IOException {
return bufferToFloatArray(Base64.getDecoder().decode(dataBase64), true);
}

public static float[] decodePcmData(String dataBase64) {
return bufferToFloatArray(Base64.getDecoder().decode(dataBase64), false);
}
}
100 changes: 66 additions & 34 deletions android/src/main/java/com/rnwhisper/RNWhisper.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Random;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.PushbackInputStream;

public class RNWhisper implements LifecycleEventListener {
Expand Down Expand Up @@ -119,44 +120,16 @@ protected void onPostExecute(Integer id) {
tasks.put(task, "initContext");
}

public void transcribeFile(double id, double jobId, String filePath, ReadableMap options, Promise promise) {
final WhisperContext context = contexts.get((int) id);
if (context == null) {
promise.reject("Context not found");
return;
}
if (context.isCapturing()) {
promise.reject("The context is in realtime transcribe mode");
return;
}
if (context.isTranscribing()) {
promise.reject("Context is already transcribing");
return;
}
private AsyncTask transcribe(WhisperContext context, double jobId, final float[] audioData, final ReadableMap options, Promise promise) {
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
protected WritableMap doInBackground(Void... voids) {
try {
String waveFilePath = filePath;

if (filePath.startsWith("http://") || filePath.startsWith("https://")) {
waveFilePath = downloader.downloadFile(filePath);
}

int resId = getResourceIdentifier(waveFilePath);
if (resId > 0) {
return context.transcribeInputStream(
(int) jobId,
reactContext.getResources().openRawResource(resId),
options
);
}

return context.transcribeInputStream(
return context.transcribe(
(int) jobId,
new FileInputStream(new File(waveFilePath)),
audioData,
options
);
} catch (Exception e) {
Expand All @@ -175,7 +148,66 @@ protected void onPostExecute(WritableMap data) {
tasks.remove(this);
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
tasks.put(task, "transcribeFile-" + id);
return task;
}

public void transcribeFile(double id, double jobId, String filePathOrBase64, ReadableMap options, Promise promise) {
final WhisperContext context = contexts.get((int) id);
if (context == null) {
promise.reject("Context not found");
return;
}
if (context.isCapturing()) {
promise.reject("The context is in realtime transcribe mode");
return;
}
if (context.isTranscribing()) {
promise.reject("Context is already transcribing");
return;
}

String waveFilePath = filePathOrBase64;
try {
if (filePathOrBase64.startsWith("http://") || filePathOrBase64.startsWith("https://")) {
waveFilePath = downloader.downloadFile(filePathOrBase64);
}

float[] audioData;
int resId = getResourceIdentifier(waveFilePath);
if (resId > 0) {
audioData = AudioUtils.decodeWaveFile(reactContext.getResources().openRawResource(resId));
} else if (filePathOrBase64.startsWith("data:audio/wav;base64,")) {
audioData = AudioUtils.decodeWaveData(filePathOrBase64);
} else {
audioData = AudioUtils.decodeWaveFile(new FileInputStream(new File(waveFilePath)));
}

AsyncTask task = transcribe(context, jobId, audioData, options, promise);
tasks.put(task, "transcribeFile-" + id);
} catch (Exception e) {
promise.reject(e);
}
}

public void transcribeData(double id, double jobId, String dataBase64, ReadableMap options, Promise promise) {
final WhisperContext context = contexts.get((int) id);
if (context == null) {
promise.reject("Context not found");
return;
}
if (context.isCapturing()) {
promise.reject("The context is in realtime transcribe mode");
return;
}
if (context.isTranscribing()) {
promise.reject("Context is already transcribing");
return;
}

float[] audioData = AudioUtils.decodePcmData(dataBase64);
AsyncTask task = transcribe(context, jobId, audioData, options, promise);

tasks.put(task, "transcribeData-" + id);
}

public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) {
Expand Down Expand Up @@ -211,7 +243,7 @@ protected Void doInBackground(Void... voids) {
context.stopTranscribe((int) jobId);
AsyncTask completionTask = null;
for (AsyncTask task : tasks.keySet()) {
if (tasks.get(task).equals("transcribeFile-" + id)) {
if (tasks.get(task).equals("transcribeFile-" + id) || tasks.get(task).equals("transcribeData-" + id)) {
task.get();
break;
}
Expand Down Expand Up @@ -259,7 +291,7 @@ protected Void doInBackground(Void... voids) {
context.stopCurrentTranscribe();
AsyncTask completionTask = null;
for (AsyncTask task : tasks.keySet()) {
if (tasks.get(task).equals("transcribeFile-" + contextId)) {
if (tasks.get(task).equals("transcribeFile-" + contextId) || tasks.get(task).equals("transcribeData-" + contextId)) {
task.get();
break;
}
Expand Down
3 changes: 1 addition & 2 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ void onNewSegments(int nNew) {
}
}

public WritableMap transcribeInputStream(int jobId, InputStream inputStream, ReadableMap options) throws IOException, Exception {
public WritableMap transcribe(int jobId, float[] audioData, ReadableMap options) throws IOException, Exception {
if (isCapturing || isTranscribing) {
throw new Exception("Context is already in capturing or transcribing");
}
Expand All @@ -341,7 +341,6 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea
this.isTdrzEnable = options.hasKey("tdrzEnable") && options.getBoolean("tdrzEnable");

isTranscribing = true;
float[] audioData = AudioUtils.decodeWaveFile(inputStream);

boolean hasProgressCallback = options.hasKey("onProgress") && options.getBoolean("onProgress");
boolean hasNewSegmentsCallback = options.hasKey("onNewSegments") && options.getBoolean("onNewSegments");
Expand Down
5 changes: 5 additions & 0 deletions android/src/newarch/java/com/rnwhisper/RNWhisperModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ public void transcribeFile(double id, double jobId, String filePath, ReadableMap
rnwhisper.transcribeFile(id, jobId, filePath, options, promise);
}

@ReactMethod
public void transcribeData(double id, double jobId, String dataBase64, ReadableMap options, Promise promise) {
rnwhisper.transcribeData(id, jobId, dataBase64, options, promise);
}

@ReactMethod
public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) {
rnwhisper.startRealtimeTranscribe(id, jobId, options, promise);
Expand Down
5 changes: 5 additions & 0 deletions android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ public void transcribeFile(double id, double jobId, String filePath, ReadableMap
rnwhisper.transcribeFile(id, jobId, filePath, options, promise);
}

@ReactMethod
public void transcribeData(double id, double jobId, String dataBase64, ReadableMap options, Promise promise) {
rnwhisper.transcribeData(id, jobId, dataBase64, options, promise);
}

@ReactMethod
public void startRealtimeTranscribe(double id, double jobId, ReadableMap options, Promise promise) {
rnwhisper.startRealtimeTranscribe(id, jobId, options, promise);
Expand Down

0 comments on commit 13137bf

Please sign in to comment.