diff --git a/cpp/rn-whisper.cpp b/cpp/rn-whisper.cpp
index ab98be6..5cf36b8 100644
--- a/cpp/rn-whisper.cpp
+++ b/cpp/rn-whisper.cpp
@@ -1,5 +1,6 @@
 #include
 #include
+#include
 #include
 #include "rn-whisper.h"
@@ -8,6 +9,74 @@
 
 namespace rnwhisper {
 
+std::string bench(struct whisper_context * ctx, int n_threads) {
+    const int n_mels = whisper_model_n_mels(ctx);
+
+    if (int ret = whisper_set_mel(ctx, nullptr, 0, n_mels)) {
+        return "error: failed to set mel: " + std::to_string(ret);
+    }
+    // heat encoder
+    if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
+        return "error: failed to encode: " + std::to_string(ret);
+    }
+
+    whisper_token tokens[512];
+    memset(tokens, 0, sizeof(tokens));
+
+    // prompt heat
+    if (int ret = whisper_decode(ctx, tokens, 256, 0, n_threads) != 0) {
+        return "error: failed to decode: " + std::to_string(ret);
+    }
+
+    // text-generation heat
+    if (int ret = whisper_decode(ctx, tokens, 1, 256, n_threads) != 0) {
+        return "error: failed to decode: " + std::to_string(ret);
+    }
+
+    whisper_reset_timings(ctx);
+
+    // actual run
+    if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
+        return "error: failed to encode: " + std::to_string(ret);
+    }
+
+    // text-generation
+    for (int i = 0; i < 256; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 1, i, n_threads) != 0) {
+            return "error: failed to decode: " + std::to_string(ret);
+        }
+    }
+
+    // batched decoding
+    for (int i = 0; i < 64; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 5, 0, n_threads) != 0) {
+            return "error: failed to decode: " + std::to_string(ret);
+        }
+    }
+
+    // prompt processing
+    for (int i = 0; i < 16; i++) {
+        if (int ret = whisper_decode(ctx, tokens, 256, 0, n_threads) != 0) {
+            return "error: failed to decode: " + std::to_string(ret);
+        }
+    }
+
+    const int64_t t_end_us = wsp_ggml_time_us();
+    whisper_timings timings = whisper_get_timings(ctx);
+    return std::string("[") +
+        std::to_string(timings.load_us) + "," +
+        std::to_string(timings.t_start_us) + "," +
+        std::to_string(t_end_us) + "," +
+        std::to_string(timings.fail_p) + "," +
+        std::to_string(timings.fail_h) + "," +
+        std::to_string(timings.t_mel_us) + "," +
+        std::to_string(timings.n_sample) + "," +
+        std::to_string(timings.n_encode) + "," +
+        std::to_string(timings.n_decode) + "," +
+        std::to_string(timings.n_batchd) + "," +
+        std::to_string(timings.n_prompt) + "]";
+}
+
 void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
     const float rc = 1.0f / (2.0f * M_PI * cutoff);
     const float dt = 1.0f / sample_rate;
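For context, the bracketed string assembled at the end of `bench()` above is returned to the JavaScript caller unchanged. Below is a minimal TypeScript sketch of unpacking it; the index order simply follows the `std::to_string` concatenation above, and `parseBenchResult` is a hypothetical helper, not part of the library's exported API.

```ts
// Hypothetical helper: the bench result is a JSON-compatible array literal, e.g.
// "[load_us, t_start_us, t_end_us, fail_p, fail_h, t_mel_us,
//   n_sample, n_encode, n_decode, n_batchd, n_prompt]"
export function parseBenchResult(result: string): number[] {
  // On failure the native side returns a plain "error: ..." string instead,
  // which JSON.parse will reject; let that surface as a thrown error.
  const values = JSON.parse(result) as number[]
  if (!Array.isArray(values) || values.length !== 11) {
    throw new Error(`Unexpected bench result: ${result}`)
  }
  return values
}
```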
diff --git a/cpp/rn-whisper.h b/cpp/rn-whisper.h
index 46adbb9..b25c881 100644
--- a/cpp/rn-whisper.h
+++ b/cpp/rn-whisper.h
@@ -9,6 +9,8 @@
 
 namespace rnwhisper {
 
+std::string bench(whisper_context * ctx, int n_threads);
+
 struct vad_params {
     bool use_vad = false;
     float vad_thold = 0.6f;
diff --git a/cpp/whisper.cpp b/cpp/whisper.cpp
index 24bfd05..34e4660 100644
--- a/cpp/whisper.cpp
+++ b/cpp/whisper.cpp
@@ -4190,28 +4190,53 @@
 whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
     return ctx->vocab.token_transcribe;
 }
 
+struct whisper_timings whisper_get_timings(struct whisper_context * ctx) {
+    const int64_t t_end_us = wsp_ggml_time_us();
+    if (ctx->state == nullptr) {
+        return {
+            .load_us = ctx->t_load_us,
+            .t_start_us = ctx->t_start_us,
+            .fail_p = 0,
+            .fail_h = 0,
+            .t_mel_us = 0,
+            .n_sample = 0,
+            .n_encode = 0,
+            .n_decode = 0,
+            .n_batchd = 0,
+            .n_prompt = 0,
+        };
+    }
+
+    return {
+        .load_us = ctx->t_load_us,
+        .t_start_us = ctx->t_start_us,
+        .fail_p = ctx->state->n_fail_p,
+        .fail_h = ctx->state->n_fail_h,
+        .t_mel_us = ctx->state->t_mel_us,
+        .n_sample = std::max(1, ctx->state->n_sample),
+        .n_encode = std::max(1, ctx->state->n_encode),
+        .n_decode = std::max(1, ctx->state->n_decode),
+        .n_batchd = std::max(1, ctx->state->n_batchd),
+        .n_prompt = std::max(1, ctx->state->n_prompt),
+    };
+}
+
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = wsp_ggml_time_us();
+    const struct whisper_timings timings = whisper_get_timings(ctx);
     WHISPER_LOG_INFO("\n");
-    WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
+    WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings.load_us / 1000.0f);
     if (ctx->state != nullptr) {
-
-        const int32_t n_sample = std::max(1, ctx->state->n_sample);
-        const int32_t n_encode = std::max(1, ctx->state->n_encode);
-        const int32_t n_decode = std::max(1, ctx->state->n_decode);
-        const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
-        const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
-
-        WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
-        WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
-        WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
-        WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
-        WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
-        WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
-        WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
-    }
-    WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
+        WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, timings.fail_p, timings.fail_h);
+        WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, timings.t_mel_us / 1000.0f);
+        WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.n_sample, timings.n_sample, 1e-3f * timings.n_sample / timings.n_sample);
+        WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.n_encode, timings.n_encode, 1e-3f * timings.n_encode / timings.n_encode);
+        WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.n_decode, timings.n_decode, 1e-3f * timings.n_decode / timings.n_decode);
+        WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.n_batchd, timings.n_batchd, 1e-3f * timings.n_batchd / timings.n_batchd);
+        WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.n_prompt, timings.n_prompt, 1e-3f * timings.n_prompt / timings.n_prompt);
+    }
+    WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - timings.t_start_us)/1000.0f);
 }
 
 void whisper_reset_timings(struct whisper_context * ctx) {
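`whisper_get_timings()` clamps the run counters with `std::max(1, ...)` so callers can divide by them safely, and `whisper_print_timings()` derives the total wall time from `t_start_us` and the current clock. Here is a small TypeScript sketch of the same total-time derivation for consumers of the bench string; it assumes the `t_start_us` and `t_end_us` values serialized by `rnwhisper::bench()` above.

```ts
// Sketch: same derivation as the "total time" line in whisper_print_timings(),
// i.e. (t_end_us - t_start_us) / 1000, expressed in milliseconds.
function totalBenchMs(tStartUs: number, tEndUs: number): number {
  return (tEndUs - tStartUs) / 1000
}
```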
diff --git a/cpp/whisper.h b/cpp/whisper.h
index 5f84c22..21b481a 100644
--- a/cpp/whisper.h
+++ b/cpp/whisper.h
@@ -424,6 +424,20 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);
 
     // Performance information from the default state.
+    struct whisper_timings {
+        int64_t load_us;
+        int64_t t_start_us;
+        int32_t fail_p;
+        int32_t fail_h;
+        int64_t t_mel_us;
+        int32_t n_sample;
+        int32_t n_encode;
+        int32_t n_decode;
+        int32_t n_batchd;
+        int32_t n_prompt;
+    };
+
+    WHISPER_API struct whisper_timings whisper_get_timings(struct whisper_context * ctx);
     WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
     WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
 
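For readers following along from the JavaScript side, the new struct can be mirrored field for field. Below is a hypothetical TypeScript shape (not exported by the library) using the same names as `whisper_timings`; once the bench result string is parsed, both the `int64_t` microsecond counters and the `int32_t` run counters arrive as plain numbers.

```ts
// Hypothetical JS-side mirror of struct whisper_timings (same field names, all numbers).
interface WhisperTimings {
  load_us: number
  t_start_us: number
  fail_p: number
  fail_h: number
  t_mel_us: number
  n_sample: number
  n_encode: number
  n_decode: number
  n_batchd: number
  n_prompt: number
}
```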
diff --git a/example/src/Bench.tsx b/example/src/Bench.tsx
index 066d835..24da55c 100644
--- a/example/src/Bench.tsx
+++ b/example/src/Bench.tsx
@@ -14,30 +14,31 @@ import contextOpts from './context-opts'
 
 const baseURL = 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/'
 
 const modelList = [
-  { name: 'tiny', coreml: true },
+  // TODO: Add coreml model download
+  { name: 'tiny' },
   { name: 'tiny-q5_1' },
   { name: 'tiny-q8_0' },
-  // { name: 'base', coreml: true },
-  // { name: 'base-q5_1' },
-  // { name: 'base-q8_0' },
-  // { name: 'small', coreml: true },
-  // { name: 'small-q5_1' },
-  // { name: 'small-q8_0' },
-  // { name: 'medium', coreml: true },
-  // { name: 'medium-q5_0' },
-  // { name: 'medium-q8_0' },
-  // { name: 'large-v1', coreml: true },
-  // { name: 'large-v1-q5_0', },
-  // { name: 'large-v1-q8_0', },
-  // { name: 'large-v2', coreml: true },
-  // { name: 'large-v2-q5_0' },
-  // { name: 'large-v2-q8_0' },
-  // { name: 'large-v3', coreml: true },
-  // { name: 'large-v3-q5_0' },
-  // { name: 'large-v3-q8_0' },
-  // { name: 'large-v3-turbo', coreml: true },
-  // { name: 'large-v3-turbo-q5_0' },
-  // { name: 'large-v3-turbo-q8_0' },
+  { name: 'base' },
+  { name: 'base-q5_1' },
+  { name: 'base-q8_0' },
+  { name: 'small' },
+  { name: 'small-q5_1' },
+  { name: 'small-q8_0' },
+  { name: 'medium' },
+  { name: 'medium-q5_0' },
+  { name: 'medium-q8_0' },
+  { name: 'large-v1' },
+  { name: 'large-v1-q5_0', },
+  { name: 'large-v1-q8_0', },
+  { name: 'large-v2' },
+  { name: 'large-v2-q5_0' },
+  { name: 'large-v2-q8_0' },
+  { name: 'large-v3' },
+  { name: 'large-v3-q5_0' },
+  { name: 'large-v3-q8_0' },
+  { name: 'large-v3-turbo' },
+  { name: 'large-v3-turbo-q5_0' },
+  { name: 'large-v3-turbo-q8_0' },
 ] as const
 
 const modelNameMap = modelList.reduce((acc, model) => {
@@ -105,6 +106,7 @@
     top: 0,
     bottom: 0,
     opacity: 0.5,
+    width: '0%',
   },
   logContainer: {
     backgroundColor: 'lightgray',
@@ -128,7 +130,14 @@ const Model = (props: {
   onDownloadStarted: (modelName: string) => void
   onDownloaded: (modelName: string) => void
 }) => {
-  const { model, state, downloadMap, setDownloadMap, onDownloadStarted, onDownloaded } = props
+  const {
+    model,
+    state,
+    downloadMap,
+    setDownloadMap,
+    onDownloadStarted,
+    onDownloaded,
+  } = props
   const downloadRef = useRef(null)
   const [progress, setProgress] = useState(0)
 
@@ -160,7 +169,6 @@
       onDownloaded(model.name)
       return
     }
-    console.log('[Model] download', `${baseURL}${model.name}.bin?download=true`)
    const { jobId, promise } = RNFS.downloadFile({
      fromUrl: `${baseURL}ggml-${model.name}.bin?download=true`,
      toFile: `${fileDir}/ggml-${model.name}.bin`,
@@ -213,8 +221,6 @@
 }
 
 export default function Bench() {
-  const whisperContextRef = useRef(null)
-  const whisperContext = whisperContextRef.current
   const [logs, setLogs] = useState([])
   const [downloadMap, setDownloadMap] = useState>(modelNameMap)
 
@@ -222,7 +228,6 @@
   const downloadedModelsRef = useRef([])
 
-
   const log = useCallback((...messages: any[]) => {
     setLogs((prev) => [...prev, messages.join(' ')])
   }, [])
@@ -292,7 +297,32 @@
       }
       ${downloadCount} models`}
-        {}}>
+        {
+          await Object.entries(downloadMap).reduce(async (acc, [modelName, downloadNeeded]) => {
+            if (!downloadNeeded) return acc
+            const filePath = `${fileDir}/ggml-${modelName}.bin`
+            if (!(await RNFS.exists(filePath))) {
+              log(`${modelName} not found, skipping`)
+              return acc
+            }
+            const ctx = await initWhisper({
+              filePath,
+              useCoreMLIos: false,
+              useGpu: Platform.OS === 'ios',
+              useFlashAttn: Platform.OS === 'ios',
+            })
+            try {
+              const result = await ctx.bench(-1)
+              log(result)
+            } finally {
+              await ctx.release()
+            }
+            return acc
+          }, Promise.resolve())
+        }}
+      >
         Run benchmark
@@ -303,9 +333,7 @@
         ))}
-        {
-          setLogs([])
-        }}>
+         setLogs([])}>
           Clear Logs
0 ? maxThreads : 0;
+
+    const int max_threads = (int) [[NSProcessInfo processInfo] processorCount];
+    // Use 2 threads by default on 4-core devices, 4 threads on more cores
+    const int default_n_threads = max_threads == 4 ? 2 : MIN(4, max_threads);
+    NSString *result = [NSString stringWithUTF8String:rnwhisper::bench(self->ctx, n_threads).c_str()];
+    return result;
+}
+
 - (void)invalidate {
     [self stopCurrentTranscribe];
     whisper_free(self->ctx);
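The Objective-C method above derives a default thread count from the device's core count. Below is a TypeScript sketch of the same heuristic, for callers who prefer to pass an explicit `maxThreads` from JS; `coreCount` would have to come from some other source (it is an assumption here, not something this library exposes).

```ts
// Mirrors the heuristic in the native bench method above:
// 2 threads on 4-core devices, otherwise at most 4 threads.
function defaultThreadCount(coreCount: number): number {
  return coreCount === 4 ? 2 : Math.min(4, coreCount)
}
```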
diff --git a/scripts/whisper.cpp.patch b/scripts/whisper.cpp.patch
index b13d16e..69ad663 100644
--- a/scripts/whisper.cpp.patch
+++ b/scripts/whisper.cpp.patch
@@ -1,5 +1,5 @@
---- whisper.cpp.orig	2024-11-03 12:39:40
-+++ whisper.cpp	2024-11-03 12:40:25
+--- whisper.cpp.orig	2024-11-07 13:09:13
++++ whisper.cpp	2024-11-07 13:03:02
 @@ -3388,8 +3388,10 @@
      const size_t memory_size = aheads_masks_nbytes(state->aheads_masks);
      WHISPER_LOG_INFO("%s: alignment heads masks size = %ld B\n", __func__, memory_size);
@@ -11,14 +11,14 @@
      const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
 
      WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
-@@ -3404,6 +3406,7 @@
- #endif
      } else {
          WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__);
-+    }
      }
++    }
  #endif
+ state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 @@ -3558,6 +3561,7 @@
  struct whisper_context_params whisper_context_default_params() {
      struct whisper_context_params result = {
@@ -27,3 +27,74 @@
         /*.flash_attn =*/ false,
         /*.gpu_device =*/ 0,
 
+@@ -4185,29 +4189,54 @@
+ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
+     return ctx->vocab.token_transcribe;
+ }
++
++struct whisper_timings whisper_get_timings(struct whisper_context * ctx) {
++    const int64_t t_end_us = wsp_ggml_time_us();
++    if (ctx->state == nullptr) {
++        return {
++            .load_us = ctx->t_load_us,
++            .t_start_us = ctx->t_start_us,
++            .fail_p = 0,
++            .fail_h = 0,
++            .t_mel_us = 0,
++            .n_sample = 0,
++            .n_encode = 0,
++            .n_decode = 0,
++            .n_batchd = 0,
++            .n_prompt = 0,
++        };
++    }
+
++    return {
++        .load_us = ctx->t_load_us,
++        .t_start_us = ctx->t_start_us,
++        .fail_p = ctx->state->n_fail_p,
++        .fail_h = ctx->state->n_fail_h,
++        .t_mel_us = ctx->state->t_mel_us,
++        .n_sample = std::max(1, ctx->state->n_sample),
++        .n_encode = std::max(1, ctx->state->n_encode),
++        .n_decode = std::max(1, ctx->state->n_decode),
++        .n_batchd = std::max(1, ctx->state->n_batchd),
++        .n_prompt = std::max(1, ctx->state->n_prompt),
++    };
++}
++
+ void whisper_print_timings(struct whisper_context * ctx) {
+     const int64_t t_end_us = wsp_ggml_time_us();
++    const struct whisper_timings timings = whisper_get_timings(ctx);
+ 
+     WHISPER_LOG_INFO("\n");
+-    WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
++    WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings.load_us / 1000.0f);
+     if (ctx->state != nullptr) {
+-
+-        const int32_t n_sample = std::max(1, ctx->state->n_sample);
+-        const int32_t n_encode = std::max(1, ctx->state->n_encode);
+-        const int32_t n_decode = std::max(1, ctx->state->n_decode);
+-        const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
+-        const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
+-
+-        WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
+-        WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
+-        WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
+-        WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
+-        WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+-        WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
+-        WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
++        WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, timings.fail_p, timings.fail_h);
++        WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, timings.t_mel_us / 1000.0f);
++        WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.n_sample, timings.n_sample, 1e-3f * timings.n_sample / timings.n_sample);
++        WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.n_encode, timings.n_encode, 1e-3f * timings.n_encode / timings.n_encode);
++        WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.n_decode, timings.n_decode, 1e-3f * timings.n_decode / timings.n_decode);
++        WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.n_batchd, timings.n_batchd, 1e-3f * timings.n_batchd / timings.n_batchd);
++        WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.n_prompt, timings.n_prompt, 1e-3f * timings.n_prompt / timings.n_prompt);
+     }
+-    WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
++    WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - timings.t_start_us)/1000.0f);
+ }
+ 
+ void whisper_reset_timings(struct whisper_context * ctx) {
diff --git a/scripts/whisper.h.patch b/scripts/whisper.h.patch
index c3620a6..a750eaa 100644
--- a/scripts/whisper.h.patch
+++ b/scripts/whisper.h.patch
@@ -1,5 +1,5 @@
---- whisper.h.orig	2024-11-03 12:37:43
-+++ whisper.h	2024-11-03 12:38:27
+--- whisper.h.orig	2024-11-07 13:09:13
++++ whisper.h	2024-11-07 13:01:44
 @@ -114,6 +114,7 @@
 
      struct whisper_context_params {
@@ -8,3 +8,24 @@
         bool flash_attn;
         int gpu_device;  // CUDA device
 
+@@ -423,6 +424,20 @@
+     WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);
+ 
+     // Performance information from the default state.
++    struct whisper_timings {
++        int64_t load_us;
++        int64_t t_start_us;
++        int32_t fail_p;
++        int32_t fail_h;
++        int64_t t_mel_us;
++        int32_t n_sample;
++        int32_t n_encode;
++        int32_t n_decode;
++        int32_t n_batchd;
++        int32_t n_prompt;
++    };
++
++    WHISPER_API struct whisper_timings whisper_get_timings(struct whisper_context * ctx);
+     WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
+     WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
+ 
diff --git a/src/NativeRNWhisper.ts b/src/NativeRNWhisper.ts
index e72d120..ab9b04a 100644
--- a/src/NativeRNWhisper.ts
+++ b/src/NativeRNWhisper.ts
@@ -86,6 +86,8 @@ export interface Spec extends TurboModule {
   ): Promise;
   abortTranscribe(contextId: number, jobId: number): Promise;
 
+  bench(contextId: number, maxThreads: number): Promise<string>;
+
   // iOS specific
   getAudioSessionCurrentCategory: () => Promise<{
     category: string,
diff --git a/src/index.ts b/src/index.ts
index 0a8f5cf..224b0a7 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -433,6 +433,10 @@ export class WhisperContext {
     }
   }
 
+  async bench(maxThreads: number): Promise<string> {
+    return RNWhisper.bench(this.id, maxThreads)
+  }
+
   async release(): Promise {
     return RNWhisper.releaseContext(this.id)
   }
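Taken together, the JS-facing addition is the `bench()` method on `WhisperContext`. Here is a minimal usage sketch based on the example app above; the model path is an assumption for illustration, and passing `-1` lets the native side pick the thread count, as `Bench.tsx` does.

```ts
import { initWhisper } from 'whisper.rn'

async function runBench(): Promise<string> {
  // Model path is assumed; point it at a ggml model you have downloaded.
  const ctx = await initWhisper({ filePath: '/path/to/ggml-tiny.bin' })
  try {
    // -1 lets the native side choose the thread count, as in the example app.
    return await ctx.bench(-1)
  } finally {
    await ctx.release()
  }
}
```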