Skip to content

Commit

Permalink
feat(ios): store job in RNWhisperContext
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Dec 6, 2023
1 parent 8ba887d commit 6fb88ed
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 75 deletions.
99 changes: 55 additions & 44 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,28 +212,7 @@ Java_com_rnwhisper_WhisperContext_vadSimple(
return is_speech;
}

struct callback_context {
JNIEnv *env;
jobject callback_instance;
};

JNIEXPORT jint JNICALL
Java_com_rnwhisper_WhisperContext_fullTranscribe(
JNIEnv *env,
jobject thiz,
jint job_id,
jlong context_ptr,
jfloatArray audio_data,
jint audio_data_len,
jobject transcribe_params,
jobject callback_instance
) {
UNUSED(thiz);
struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);

LOGI("About to create params");

struct whisper_full_params createFullParams(JNIEnv *env, jobject options) {
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

params.print_realtime = false;
Expand All @@ -244,40 +223,72 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
int n_threads = readablemap::getInt(env, transcribe_params, "maxThreads", default_n_threads);
int n_threads = readablemap::getInt(env, options, "maxThreads", default_n_threads);
params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
params.translate = readablemap::getBool(env, transcribe_params, "translate", false);
params.speed_up = readablemap::getBool(env, transcribe_params, "speedUp", false);
params.token_timestamps = readablemap::getBool(env, transcribe_params, "tokenTimestamps", false);
params.translate = readablemap::getBool(env, options, "translate", false);
params.speed_up = readablemap::getBool(env, options, "speedUp", false);
params.token_timestamps = readablemap::getBool(env, options, "tokenTimestamps", false);
params.offset_ms = 0;
params.no_context = true;
params.single_segment = false;

int beam_size = readablemap::getInt(env, transcribe_params, "beamSize", -1);
int beam_size = readablemap::getInt(env, options, "beamSize", -1);
if (beam_size > -1) {
params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
params.beam_search.beam_size = beam_size;
}
int best_of = readablemap::getInt(env, transcribe_params, "bestOf", -1);
int best_of = readablemap::getInt(env, options, "bestOf", -1);
if (best_of > -1) params.greedy.best_of = best_of;
int max_len = readablemap::getInt(env, transcribe_params, "maxLen", -1);
int max_len = readablemap::getInt(env, options, "maxLen", -1);
if (max_len > -1) params.max_len = max_len;
int max_context = readablemap::getInt(env, transcribe_params, "maxContext", -1);
int max_context = readablemap::getInt(env, options, "maxContext", -1);
if (max_context > -1) params.n_max_text_ctx = max_context;
int offset = readablemap::getInt(env, transcribe_params, "offset", -1);
int offset = readablemap::getInt(env, options, "offset", -1);
if (offset > -1) params.offset_ms = offset;
int duration = readablemap::getInt(env, transcribe_params, "duration", -1);
int duration = readablemap::getInt(env, options, "duration", -1);
if (duration > -1) params.duration_ms = duration;
int word_thold = readablemap::getInt(env, transcribe_params, "wordThold", -1);
int word_thold = readablemap::getInt(env, options, "wordThold", -1);
if (word_thold > -1) params.thold_pt = word_thold;
float temperature = readablemap::getFloat(env, transcribe_params, "temperature", -1);
float temperature = readablemap::getFloat(env, options, "temperature", -1);
if (temperature > -1) params.temperature = temperature;
float temperature_inc = readablemap::getFloat(env, transcribe_params, "temperatureInc", -1);
float temperature_inc = readablemap::getFloat(env, options, "temperatureInc", -1);
if (temperature_inc > -1) params.temperature_inc = temperature_inc;
jstring prompt = readablemap::getString(env, transcribe_params, "prompt", nullptr);
if (prompt != nullptr) params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
jstring language = readablemap::getString(env, transcribe_params, "language", nullptr);
if (language != nullptr) params.language = env->GetStringUTFChars(language, nullptr);
jstring prompt = readablemap::getString(env, options, "prompt", nullptr);
if (prompt != nullptr) {
params.initial_prompt = env->GetStringUTFChars(prompt, nullptr);
env->ReleaseStringUTFChars(prompt, params.initial_prompt);
}
jstring language = readablemap::getString(env, options, "language", nullptr);
if (language != nullptr) {
params.language = env->GetStringUTFChars(language, nullptr);
env->ReleaseStringUTFChars(language, params.language);
}
return params;
}

struct callback_context {
JNIEnv *env;
jobject callback_instance;
};

JNIEXPORT jint JNICALL
Java_com_rnwhisper_WhisperContext_fullTranscribe(
JNIEnv *env,
jobject thiz,
jint job_id,
jlong context_ptr,
jfloatArray audio_data,
jint audio_data_len,
jobject options,
jobject callback_instance
) {
UNUSED(thiz);
struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);

LOGI("About to create params");

whisper_full_params params = createFullParams(env, options);

if (callback_instance != nullptr) {
callback_context *cb_ctx = new callback_context;
Expand Down Expand Up @@ -305,7 +316,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
params.new_segment_callback_user_data = cb_ctx;
}

rnwhisper::job job = rnwhisper::job_new(job_id, params);
rnwhisper::job* job = rnwhisper::job_new(job_id, params);

LOGI("About to reset timings");
whisper_reset_timings(context);
Expand All @@ -316,14 +327,14 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
// whisper_print_timings(context);
}
env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
if (language != nullptr) env->ReleaseStringUTFChars(language, params.language);
if (prompt != nullptr) env->ReleaseStringUTFChars(prompt, params.initial_prompt);

if (job.is_aborted()) code = -999;

if (job->is_aborted()) code = -999;
rnwhisper::job_remove(job_id);
return code;
}

// TODO: full for realtimeTranscribe with job_id (need create job first)

JNIEXPORT void JNICALL
Java_com_rnwhisper_WhisperContext_abortTranscribe(
JNIEnv *env,
Expand Down
4 changes: 2 additions & 2 deletions cpp/rn-whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void job_abort_all() {
}
}

job job_new(int job_id, struct whisper_full_params params) {
job* job_new(int job_id, struct whisper_full_params params) {
job ctx;
ctx.job_id = job_id;
ctx.params = params;
Expand All @@ -44,7 +44,7 @@ job job_new(int job_id, struct whisper_full_params params) {
params.abort_callback_user_data = &ctx;

job_map[job_id] = ctx;
return ctx;
return &job_map[job_id];
}

void job_remove(int job_id) {
Expand Down
2 changes: 1 addition & 1 deletion cpp/rn-whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct job {
};

void job_abort_all();
job job_new(int job_id, struct whisper_full_params params);
job* job_new(int job_id, struct whisper_full_params params);
void job_remove(int job_id);
job* job_get(int job_id);

Expand Down
3 changes: 2 additions & 1 deletion example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ export default function App() {
log('Start transcribing...')
const startTime = Date.now()
const { stop, promise } = whisperContext.transcribe(sampleFile, {
language: 'en',
language: 'zh',
prompt: 'HELLO WORLD',
maxLen: 1,
tokenTimestamps: true,
onProgress: (cur) => {
Expand Down
4 changes: 2 additions & 2 deletions ios/RNWhisperContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

typedef struct {
__unsafe_unretained id mSelf;

int jobId;
NSDictionary* options;

struct rnwhisper::job * job;

bool isTranscribing;
bool isRealtime;
bool isCapturing;
Expand Down
45 changes: 20 additions & 25 deletions ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ - (void)finishRealtimeTranscribe:(RNWhisperContextRecordState*) state result:(NS
audioOutputFile:state->audioOutputPath
];
}
state->transcribeHandler(state->jobId, @"end", result);
state->transcribeHandler(state->job->job_id, @"end", result);
rnwhisper::job_remove(state->job->job_id);
}

- (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state {
Expand All @@ -290,8 +291,9 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state {
audioBufferF32[i] = (float)audioBufferI16[i] / 32768.0f;
}
CFTimeInterval timeStart = CACurrentMediaTime();
struct whisper_full_params params = [state->mSelf getParams:state->options jobId:state->jobId];
int code = [state->mSelf fullTranscribe:state->jobId params:params audioData:audioBufferF32 audioDataCount:state->nSamplesTranscribing];

int code = [state->mSelf fullTranscribe:state->job audioData:audioBufferF32 audioDataCount:state->nSamplesTranscribing];

free(audioBufferF32);
CFTimeInterval timeEnd = CACurrentMediaTime();
const float timeRecording = (float) state->nSamplesTranscribing / (float) state->dataFormat.mSampleRate;
Expand Down Expand Up @@ -340,10 +342,10 @@ - (void)fullTranscribeSamples:(RNWhisperContextRecordState*) state {
[state->mSelf finishRealtimeTranscribe:state result:result];
} else if (code == 0) {
result[@"isCapturing"] = @(true);
state->transcribeHandler(state->jobId, @"transcribe", result);
state->transcribeHandler(state->job->job_id, @"transcribe", result);
} else {
result[@"isCapturing"] = @(true);
state->transcribeHandler(state->jobId, @"transcribe", result);
state->transcribeHandler(state->job->job_id, @"transcribe", result);
}

if (continueNeeded) {
Expand Down Expand Up @@ -371,8 +373,8 @@ - (OSStatus)transcribeRealtime:(int)jobId
onTranscribe:(void (^)(int, NSString *, NSDictionary *))onTranscribe
{
self->recordState.transcribeHandler = onTranscribe;
self->recordState.jobId = jobId;
[self prepareRealtime:options];
self->recordState.job = rnwhisper::job_new(jobId, [self createParams:options jobId:jobId]);

OSStatus status = AudioQueueNewInput(
&self->recordState.dataFormat,
Expand Down Expand Up @@ -413,9 +415,9 @@ - (void)transcribeFile:(int)jobId
dispatch_async(dQueue, ^{
self->recordState.isStoppedByAction = false;
self->recordState.isTranscribing = true;
self->recordState.jobId = jobId;

whisper_full_params params = [self getParams:options jobId:jobId];
whisper_full_params params = [self createParams:options jobId:jobId];

if (options[@"onProgress"] && [options[@"onProgress"] boolValue]) {
params.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
void (^onProgress)(int) = (__bridge void (^)(int))user_data;
Expand Down Expand Up @@ -460,8 +462,10 @@ - (void)transcribeFile:(int)jobId
};
params.new_segment_callback_user_data = &user_data;
}
int code = [self fullTranscribe:jobId params:params audioData:audioData audioDataCount:audioDataCount];
self->recordState.jobId = -1;

rnwhisper::job* job = rnwhisper::job_new(jobId, params);;
int code = [self fullTranscribe:job audioData:audioData audioDataCount:audioDataCount];
rnwhisper::job_remove(jobId);
self->recordState.isTranscribing = false;
onEnd(code);
});
Expand All @@ -476,8 +480,7 @@ - (void)stopAudio {
}

- (void)stopTranscribe:(int)jobId {
rnwhisper::job *job = rnwhisper::job_get(jobId);
if (job) job->abort();
if (self->recordState.job) self->recordState.job->abort();
if (self->recordState.isRealtime && self->recordState.isCapturing) {
[self stopAudio];
if (!self->recordState.isTranscribing) {
Expand All @@ -491,13 +494,11 @@ - (void)stopTranscribe:(int)jobId {
}

- (void)stopCurrentTranscribe {
if (!self->recordState.jobId) {
return;
}
[self stopTranscribe:self->recordState.jobId];
if (self->recordState.job == nullptr) return;
[self stopTranscribe:self->recordState.job->job_id];
}

- (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId {
- (struct whisper_full_params)createParams:(NSDictionary *)options jobId:(int)jobId {
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

const int n_threads = options[@"maxThreads"] != nil ?
Expand Down Expand Up @@ -554,22 +555,16 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId
params.initial_prompt = [options[@"prompt"] UTF8String];
}

rnwhisper::job_new(jobId, params);
return params;
}

- (int)fullTranscribe:(int)jobId
params:(struct whisper_full_params)params
- (int)fullTranscribe:(rnwhisper::job *)job
audioData:(float *)audioData
audioDataCount:(int)audioDataCount
{
whisper_reset_timings(self->ctx);

int code = whisper_full(self->ctx, params, audioData, audioDataCount);

rnwhisper::job* job = rnwhisper::job_get(jobId);
int code = whisper_full(self->ctx, job->params, audioData, audioDataCount);
if (job && job->is_aborted()) code = -999;
rnwhisper::job_remove(jobId);
// if (code == 0) {
// whisper_print_timings(self->ctx);
// }
Expand Down

0 comments on commit 6fb88ed

Please sign in to comment.