Skip to content

Commit

Permalink
feat: add tdrzEnable option
Browse files Browse the repository at this point in the history
  • Loading branch information
SooryR committed Sep 28, 2024
1 parent b70e391 commit 67b9758
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 7 deletions.
16 changes: 14 additions & 2 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public class WhisperContext {
private boolean isTranscribing = false;
private Thread rootFullHandler = null;
private Thread fullHandler = null;
private ReadableMap options;

public WhisperContext(int id, ReactApplicationContext reactContext, long context) {
this.id = id;
Expand Down Expand Up @@ -103,7 +104,7 @@ public int startRealtimeTranscribe(int jobId, ReadableMap options) {
rewind();

this.jobId = jobId;

this.options = options;
int realtimeAudioSec = options.hasKey("realtimeAudioSec") ? options.getInt("realtimeAudioSec") : 0;
final int audioSec = realtimeAudioSec > 0 ? realtimeAudioSec : DEFAULT_MAX_AUDIO_SEC;
int realtimeAudioSliceSec = options.hasKey("realtimeAudioSliceSec") ? options.getInt("realtimeAudioSliceSec") : 0;
Expand Down Expand Up @@ -333,7 +334,7 @@ public WritableMap transcribeInputStream(int jobId, InputStream inputStream, Rea
throw new Exception("Context is already in capturing or transcribing");
}
rewind();

this.options = options;
this.jobId = jobId;
isTranscribing = true;
float[] audioData = AudioUtils.decodeWaveFile(inputStream);
Expand Down Expand Up @@ -368,8 +369,18 @@ private WritableMap getTextSegments(int start, int count) {

WritableMap data = Arguments.createMap();
WritableArray segments = Arguments.createArray();

// Check if tdrzEnable is enabled
boolean tdrzEnable = options != null && options.hasKey("tdrzEnable") && options.getBoolean("tdrzEnable");

for (int i = 0; i < count; i++) {
String text = getTextSegment(context, i);

// If tdrzEnable is enabled and speaker turn is detected
if (tdrzEnable && getTextSegmentSpeakerTurnNext(context, i)) {
text += " [SPEAKER_TURN]";
}

builder.append(text);

WritableMap segment = Arguments.createMap();
Expand Down Expand Up @@ -499,6 +510,7 @@ protected static native int fullWithNewJob(
protected static native String getTextSegment(long context, int index);
protected static native int getTextSegmentT0(long context, int index);
protected static native int getTextSegmentT1(long context, int index);
protected static native boolean getTextSegmentSpeakerTurnNext(long context, int index);

protected static native void createRealtimeTranscribeJob(
int job_id,
Expand Down
10 changes: 10 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ struct whisper_full_params createFullParams(JNIEnv *env, jobject options) {
params.translate = readablemap::getBool(env, options, "translate", false);
params.speed_up = readablemap::getBool(env, options, "speedUp", false);
params.token_timestamps = readablemap::getBool(env, options, "tokenTimestamps", false);
params.tdrz_enable = readablemap::getBool(env, options, "tdrzEnable", false);
params.offset_ms = 0;
params.no_context = true;
params.single_segment = false;
Expand Down Expand Up @@ -493,4 +494,13 @@ Java_com_rnwhisper_WhisperContext_freeContext(
whisper_free(context);
}

JNIEXPORT jboolean JNICALL
Java_com_rnwhisper_WhisperContext_getTextSegmentSpeakerTurnNext(
JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
UNUSED(env);
UNUSED(thiz);
struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
return whisper_full_get_segment_speaker_turn_next(context, index);
}

} // extern "C"
1 change: 1 addition & 0 deletions docs/API/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ ___
| `offset?` | `number` | Time offset in milliseconds |
| `prompt?` | `string` | Initial Prompt |
| `speedUp?` | `boolean` | Speed up audio by x2 (reduced accuracy) |
| `tdrzEnable?` | `boolean` | Enable tinydiarize (requires a tdrz model) |
| `temperature?` | `number` | Tnitial decoding temperature |
| `temperatureInc?` | `number` | - |
| `tokenTimestamps?` | `boolean` | Enable token-level timestamps |
Expand Down
28 changes: 23 additions & 5 deletions ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ - (OSStatus)transcribeRealtime:(int)jobId
struct rnwhisper_segments_callback_data {
void (^onNewSegments)(NSDictionary *);
int total_n_new;
bool tdrz_enable;
};

- (void)transcribeFile:(int)jobId
Expand Down Expand Up @@ -386,12 +387,18 @@ - (void)transcribeFile:(int)jobId
NSMutableArray *segments = [[NSMutableArray alloc] init];
for (int i = data->total_n_new - n_new; i < data->total_n_new; i++) {
const char * text_cur = whisper_full_get_segment_text(ctx, i);
text = [text stringByAppendingString:[NSString stringWithUTF8String:text_cur]];
NSMutableString *mutable_ns_text = [NSMutableString stringWithUTF8String:text_cur];

if (data->tdrz_enable && whisper_full_get_segment_speaker_turn_next(ctx, i)) {
[mutable_ns_text appendString:@" [SPEAKER_TURN]"];
}

text = [text stringByAppendingString:[NSString stringWithString:mutable_ns_text]];

const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
NSDictionary *segment = @{
@"text": [NSString stringWithUTF8String:text_cur],
@"text": [NSString stringWithString:mutable_ns_text],
@"t0": [NSNumber numberWithLongLong:t0],
@"t1": [NSNumber numberWithLongLong:t1]
};
Expand All @@ -409,7 +416,8 @@ - (void)transcribeFile:(int)jobId
};
struct rnwhisper_segments_callback_data user_data = {
.onNewSegments = onNewSegments,
.total_n_new = 0
.total_n_new = 0,
.tdrz_enable = params.tdrz_enable,
};
params.new_segment_callback_user_data = &user_data;
}
Expand Down Expand Up @@ -481,6 +489,7 @@ - (struct whisper_full_params)createParams:(NSDictionary *)options jobId:(int)jo
params.max_len = [options[@"maxLen"] intValue];
}
params.token_timestamps = options[@"tokenTimestamps"] != nil ? [options[@"tokenTimestamps"] boolValue] : false;
params.tdrz_enable = options[@"tdrzEnable"] != nil ? [options[@"tdrzEnable"] boolValue] : false;

if (options[@"bestOf"] != nil) {
params.greedy.best_of = [options[@"bestOf"] intValue];
Expand Down Expand Up @@ -530,12 +539,21 @@ - (NSMutableDictionary *)getTextSegments {
NSMutableArray *segments = [[NSMutableArray alloc] init];
for (int i = 0; i < n_segments; i++) {
const char * text_cur = whisper_full_get_segment_text(self->ctx, i);
text = [text stringByAppendingString:[NSString stringWithUTF8String:text_cur]];
NSMutableString *mutable_ns_text = [NSMutableString stringWithUTF8String:text_cur];

// Simplified condition
if (self->recordState.options[@"tdrzEnable"] &&
[self->recordState.options[@"tdrzEnable"] boolValue] &&
whisper_full_get_segment_speaker_turn_next(self->ctx, i)) {
[mutable_ns_text appendString:@" [SPEAKER_TURN]"];
}

text = [text stringByAppendingString:mutable_ns_text];

const int64_t t0 = whisper_full_get_segment_t0(self->ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(self->ctx, i);
NSDictionary *segment = @{
@"text": [NSString stringWithUTF8String:text_cur],
@"text": [NSString stringWithString:mutable_ns_text],
@"t0": [NSNumber numberWithLongLong:t0],
@"t1": [NSNumber numberWithLongLong:t1]
};
Expand Down
2 changes: 2 additions & 0 deletions src/NativeRNWhisper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export type TranscribeOptions = {
maxLen?: number,
/** Enable token-level timestamps */
tokenTimestamps?: boolean,
/** Enable token-level timestamps */
tdrzEnable?: boolean,
/** Word timestamp probability threshold */
wordThold?: number,
/** Time offset in milliseconds */
Expand Down

0 comments on commit 67b9758

Please sign in to comment.