diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index bd658ddd..7a4022ab 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -55,6 +55,7 @@ public class WhisperContext { private boolean isTranscribing = false; private Thread rootFullHandler = null; private Thread fullHandler = null; + private ReadableMap options; public WhisperContext(int id, ReactApplicationContext reactContext, long context) { this.id = id; @@ -103,7 +104,7 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) { rewind(); this.jobId = jobId; - + this.options = options; int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0; final int audioSec = realtimeAudioSec > 0 ? realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC; int realtimeAudioSliceSec = options.hasKey("realtimeAudioSliceSec") ? options.getInt("realtimeAudioSliceSec") : 0; @@ -333,7 +334,7 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea throw new Exception("Context is already in capturing or transcribing"); } rewind(); - + this.options = options; this.jobId = jobId; isTranscribing = true; float[] audioData = AudioUtils.decodeWaveFile(inputStream); @@ -368,8 +369,18 @@ private WritableMap getTextSegments(int start, int count) { WritableMap data = Arguments.createMap(); WritableArray segments = Arguments.createArray(); + + // Check if tdrzEnable is enabled + boolean tdrzEnable = options != null && options.hasKey("tdrzEnable") && options.getBoolean("tdrzEnable"); + for (int i = 0; i < count; i++) { String text = getTextSegment(context, i); + + // If tdrzEnable is enabled and speaker turn is detected + if (tdrzEnable && getTextSegmentSpeakerTurnNext(context, i)) { + text += " [SPEAKER_TURN]"; + } + builder.append(text); WritableMap segment = Arguments.createMap(); @@ -499,6 +510,7 @@ protected static native int fullWithNewJob( protected static native String getTextSegment(long context, int index); protected static native int getTextSegmentT0(long context, int index); protected static native int getTextSegmentT1(long context, int index); + protected static native boolean getTextSegmentSpeakerTurnNext(long context, int index); protected static native void createRealtimeTranscribeJob( int job_id, diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 7a1f3bf6..09cd3f9d 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -208,6 +208,7 @@ struct whisper_full_params createFullParams(JNIEnv *env, jobject options) { params.translate = readablemap::getBool(env, options, "translate", false); params.speed_up = readablemap::getBool(env, options, "speedUp", false); params.token_timestamps = readablemap::getBool(env, options, "tokenTimestamps", false); + params.tdrz_enable = readablemap::getBool(env, options, "tdrzEnable", false); params.offset_ms = 0; params.no_context = true; params.single_segment = false; @@ -493,4 +494,13 @@ Java_com_rnwhisper_WhisperContext_freeContext( whisper_free(context); } +JNIEXPORT jboolean JNICALL +Java_com_rnwhisper_WhisperContext_getTextSegmentSpeakerTurnNext( + JNIEnv *env, jobject thiz, jlong context_ptr, jint index) { + UNUSED(env); + UNUSED(thiz); + struct whisper_context *context = reinterpret_cast(context_ptr); + return whisper_full_get_segment_speaker_turn_next(context, index); +} + } // extern "C" diff --git a/docs/API/README.md b/docs/API/README.md index c8ecbb5b..61d6d25e 100644 --- a/docs/API/README.md +++ b/docs/API/README.md @@ -149,6 +149,7 @@ ___ | `offset?` | `number` | Time offset in milliseconds | | `prompt?` | `string` | Initial Prompt | | `speedUp?` | `boolean` | Speed up audio by x2 (reduced accuracy) | +| `tdrzEnable?` | `boolean` | Enable tinydiarize (requires a tdrz model) | | `temperature?` | `number` | Tnitial decoding temperature | | `temperatureInc?` | `number` | - | | `tokenTimestamps?` | `boolean` | Enable token-level timestamps | diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index 5012def2..7cdcb3ae 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -353,6 +353,7 @@ - (OSStatus)transcribeRealtime:(int)jobId struct rnwhisper_segments_callback_data { void (^onNewSegments)(NSDictionary *); int total_n_new; + bool tdrz_enable; }; - (void)transcribeFile:(int)jobId @@ -386,12 +387,18 @@ - (void)transcribeFile:(int)jobId NSMutableArray *segments = [[NSMutableArray alloc] init]; for (int i = data->total_n_new - n_new; i < data->total_n_new; i++) { const char * text_cur = whisper_full_get_segment_text(ctx, i); - text = [text stringByAppendingString:[NSString stringWithUTF8String:text_cur]]; + NSMutableString *mutable_ns_text = [NSMutableString stringWithUTF8String:text_cur]; + + if (data->tdrz_enable && whisper_full_get_segment_speaker_turn_next(ctx, i)) { + [mutable_ns_text appendString:@" [SPEAKER_TURN]"]; + } + + text = [text stringByAppendingString:[NSString stringWithString:mutable_ns_text]]; const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); NSDictionary *segment = @{ - @"text": [NSString stringWithUTF8String:text_cur], + @"text": [NSString stringWithString:mutable_ns_text], @"t0": [NSNumber numberWithLongLong:t0], @"t1": [NSNumber numberWithLongLong:t1] }; @@ -409,7 +416,8 @@ - (void)transcribeFile:(int)jobId }; struct rnwhisper_segments_callback_data user_data = { .onNewSegments = onNewSegments, - .total_n_new = 0 + .total_n_new = 0, + .tdrz_enable = params.tdrz_enable, }; params.new_segment_callback_user_data = &user_data; } @@ -481,6 +489,7 @@ - (struct whisper_full_params)createParams:(NSDictionary *)options jobId:(int)jo params.max_len = [options[@"maxLen"] intValue]; } params.token_timestamps = options[@"tokenTimestamps"] != nil ? [options[@"tokenTimestamps"] boolValue] : false; + params.tdrz_enable = options[@"tdrzEnable"] != nil ? [options[@"tdrzEnable"] boolValue] : false; if (options[@"bestOf"] != nil) { params.greedy.best_of = [options[@"bestOf"] intValue]; @@ -530,12 +539,21 @@ - (NSMutableDictionary *)getTextSegments { NSMutableArray *segments = [[NSMutableArray alloc] init]; for (int i = 0; i < n_segments; i++) { const char * text_cur = whisper_full_get_segment_text(self->ctx, i); - text = [text stringByAppendingString:[NSString stringWithUTF8String:text_cur]]; + NSMutableString *mutable_ns_text = [NSMutableString stringWithUTF8String:text_cur]; + + // Simplified condition + if (self->recordState.options[@"tdrzEnable"] && + [self->recordState.options[@"tdrzEnable"] boolValue] && + whisper_full_get_segment_speaker_turn_next(self->ctx, i)) { + [mutable_ns_text appendString:@" [SPEAKER_TURN]"]; + } + + text = [text stringByAppendingString:mutable_ns_text]; const int64_t t0 = whisper_full_get_segment_t0(self->ctx, i); const int64_t t1 = whisper_full_get_segment_t1(self->ctx, i); NSDictionary *segment = @{ - @"text": [NSString stringWithUTF8String:text_cur], + @"text": [NSString stringWithString:mutable_ns_text], @"t0": [NSNumber numberWithLongLong:t0], @"t1": [NSNumber numberWithLongLong:t1] }; diff --git a/src/NativeRNWhisper.ts b/src/NativeRNWhisper.ts index 8290f62c..93e55bb6 100644 --- a/src/NativeRNWhisper.ts +++ b/src/NativeRNWhisper.ts @@ -15,6 +15,8 @@ export type TranscribeOptions = { maxLen?: number, /** Enable token-level timestamps */ tokenTimestamps?: boolean, + /** Enable token-level timestamps */ + tdrzEnable?: boolean, /** Word timestamp probability threshold */ wordThold?: number, /** Time offset in milliseconds */