Skip to content

Commit

Permalink
fix: avoid last result on abort realtime transcription
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Nov 7, 2023
1 parent aa2effb commit 58c4c96
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 20 deletions.
6 changes: 3 additions & 3 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe

if (code == 0) {
payload.putMap("data", getTextSegments(0, getTextSegmentCount(context)));
} else {
} else if (code != -999) { // Not aborted
payload.putString("error", "Transcribe failed with code " + code);
}

Expand All @@ -297,7 +297,7 @@ private void fullTranscribeSamples(ReadableMap options, boolean skipCapturingChe
nSamplesTranscribing = 0;
}

boolean continueNeeded = !isCapturing && nSamplesTranscribing != nSamplesOfIndex;
boolean continueNeeded = !isCapturing && nSamplesTranscribing != nSamplesOfIndex && code != -999;

if (isStopped && !continueNeeded) {
payload.putBoolean("isCapturing", false);
Expand Down Expand Up @@ -386,7 +386,7 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea
int code = full(jobId, options, audioData, audioData.length);
isTranscribing = false;
this.jobId = -1;
if (code != 0) {
if (code != 0 && code != -999) { // -999 = aborted (set in jni.cpp fullTranscribe); not a failure
throw new Exception("Failed to transcribe the file. Code: " + code);
}
WritableMap result = getTextSegments(0, getTextSegmentCount(context));
Expand Down
8 changes: 6 additions & 2 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,16 +297,17 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
}

// abort handlers
bool* abort_ptr = rn_whisper_assign_abort_map(job_id);
params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(job_id);
params.encoder_begin_callback_user_data = abort_ptr;
params.abort_callback = [](void * user_data) {
bool is_aborted = *(bool*)user_data;
return is_aborted;
};
params.abort_callback_user_data = rn_whisper_assign_abort_map(job_id);
params.abort_callback_user_data = abort_ptr;

if (callback_instance != nullptr) {
callback_context *cb_ctx = new callback_context;
Expand Down Expand Up @@ -344,6 +345,9 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
}
env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
env->ReleaseStringUTFChars(language, language_chars);
if (rn_whisper_transcribe_is_aborted(job_id)) {
code = -999;
}
rn_whisper_remove_abort_map(job_id);
return code;
}
Expand Down
2 changes: 1 addition & 1 deletion ios/RNWhisper.mm
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ - (NSArray *)supportedEvents {
});
}
onEnd: ^(int code) {
if (code != 0) {
if (code != 0 && code != -999) { // -999 = aborted (set in RNWhisperContext fullTranscribe); not a failure
free(waveFile);
reject(@"whisper_cpp_error", [NSString stringWithFormat:@"Failed to transcribe the file. Code: %d", code], nil);
return;
Expand Down
8 changes: 6 additions & 2 deletions ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -556,16 +556,17 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId
}

// abort handler
bool *abort_ptr = rn_whisper_assign_abort_map(jobId);
params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
params.encoder_begin_callback_user_data = rn_whisper_assign_abort_map(jobId);
params.encoder_begin_callback_user_data = abort_ptr;
params.abort_callback = [](void * user_data) {
bool is_aborted = *(bool*)user_data;
return is_aborted;
};
params.abort_callback_user_data = rn_whisper_assign_abort_map(jobId);
params.abort_callback_user_data = abort_ptr;

return params;
}
Expand All @@ -578,6 +579,9 @@ - (int)fullTranscribe:(int)jobId
whisper_reset_timings(self->ctx);

int code = whisper_full(self->ctx, params, audioData, audioDataCount);
if (rn_whisper_transcribe_is_aborted(jobId)) {
code = -999;
}
rn_whisper_remove_abort_map(jobId);
// if (code == 0) {
// whisper_print_timings(self->ctx);
Expand Down
22 changes: 10 additions & 12 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -307,25 +307,23 @@ export class WhisperContext {
let tOffset: number = 0

const putSlice = (payload: TranscribeRealtimeNativePayload) => {
if (!payload.isUseSlices) return
if (!payload.isUseSlices || !payload.data) return
if (sliceIndex !== payload.sliceIndex) {
const { segments = [] } = slices[sliceIndex]?.data || {}
tOffset = segments[segments.length - 1]?.t1 || 0
}
;({ sliceIndex } = payload)
slices[sliceIndex] = {
...payload,
data: payload.data
? {
...payload.data,
segments:
payload.data.segments.map((segment) => ({
...segment,
t0: segment.t0 + tOffset,
t1: segment.t1 + tOffset,
})) || [],
}
: undefined,
data: {
...payload.data,
segments:
payload.data.segments.map((segment) => ({
...segment,
t0: segment.t0 + tOffset,
t1: segment.t1 + tOffset,
})) || [],
}
}
}

Expand Down

0 comments on commit 58c4c96

Please sign in to comment.