feat(ios, cpp): add bench method
jhen0409 committed Nov 7, 2024
1 parent 492a6dc commit 2b7b8e7
Showing 12 changed files with 321 additions and 55 deletions.
cpp/rn-whisper.cpp (69 additions, 0 deletions)
@@ -1,5 +1,6 @@
#include <cstdio>
#include <cstring> // for memset
#include <string>
#include <sstream>
#include <vector>
#include <unordered_map>
#include "rn-whisper.h"
@@ -8,6 +9,74 @@

namespace rnwhisper {

std::string bench(struct whisper_context * ctx, int n_threads) {
    const int n_mels = whisper_model_n_mels(ctx);

    // give the context an empty mel spectrogram so the encoder can run on dummy data
    if (int ret = whisper_set_mel(ctx, nullptr, 0, n_mels)) {
        return "error: failed to set mel: " + std::to_string(ret);
    }

    // warm up the encoder
    if (int ret = whisper_encode(ctx, 0, n_threads)) {
        return "error: failed to encode: " + std::to_string(ret);
    }

    whisper_token tokens[512];
    memset(tokens, 0, sizeof(tokens));

    // warm up prompt processing (256 tokens at once)
    if (int ret = whisper_decode(ctx, tokens, 256, 0, n_threads)) {
        return "error: failed to decode: " + std::to_string(ret);
    }

    // warm up single-token text generation
    if (int ret = whisper_decode(ctx, tokens, 1, 256, n_threads)) {
        return "error: failed to decode: " + std::to_string(ret);
    }

    whisper_reset_timings(ctx);

    // measured encoder run
    if (int ret = whisper_encode(ctx, 0, n_threads)) {
        return "error: failed to encode: " + std::to_string(ret);
    }

    // text generation: 256 single-token decodes
    for (int i = 0; i < 256; i++) {
        if (int ret = whisper_decode(ctx, tokens, 1, i, n_threads)) {
            return "error: failed to decode: " + std::to_string(ret);
        }
    }

    // batched decoding: 64 runs of 5 tokens
    for (int i = 0; i < 64; i++) {
        if (int ret = whisper_decode(ctx, tokens, 5, 0, n_threads)) {
            return "error: failed to decode: " + std::to_string(ret);
        }
    }

    // prompt processing: 16 runs of 256 tokens
    for (int i = 0; i < 16; i++) {
        if (int ret = whisper_decode(ctx, tokens, 256, 0, n_threads)) {
            return "error: failed to decode: " + std::to_string(ret);
        }
    }

    const int64_t t_end_us = wsp_ggml_time_us();
    whisper_timings timings = whisper_get_timings(ctx);

    // [load_us, t_start_us, t_end_us, fail_p, fail_h, t_mel_us,
    //  n_sample, n_encode, n_decode, n_batchd, n_prompt]
    return std::string("[") +
        std::to_string(timings.load_us) + "," +
        std::to_string(timings.t_start_us) + "," +
        std::to_string(t_end_us) + "," +
        std::to_string(timings.fail_p) + "," +
        std::to_string(timings.fail_h) + "," +
        std::to_string(timings.t_mel_us) + "," +
        std::to_string(timings.n_sample) + "," +
        std::to_string(timings.n_encode) + "," +
        std::to_string(timings.n_decode) + "," +
        std::to_string(timings.n_batchd) + "," +
        std::to_string(timings.n_prompt) + "]";
}

void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
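A note on the bench() return value above: it is a flat JSON-style array whose element order follows the C++ return statement, i.e. load_us, t_start_us, t_end_us, fail_p, fail_h, t_mel_us, n_sample, n_encode, n_decode, n_batchd, n_prompt. As a minimal sketch, a JS consumer could decode it as follows; the parseBenchResult helper and its camelCase field names are illustrative and not part of this commit:

// Decode the "[a,b,...]" string from bench() into named fields,
// following the element order of the C++ return statement above.
type BenchTimings = {
  loadUs: number
  tStartUs: number
  tEndUs: number
  failP: number
  failH: number
  tMelUs: number
  nSample: number
  nEncode: number
  nDecode: number
  nBatchd: number
  nPrompt: number
}

const parseBenchResult = (raw: string): BenchTimings => {
  // The returned string is valid JSON, so JSON.parse handles the splitting
  const [loadUs, tStartUs, tEndUs, failP, failH, tMelUs,
    nSample, nEncode, nDecode, nBatchd, nPrompt] = JSON.parse(raw) as number[]
  return { loadUs, tStartUs, tEndUs, failP, failH, tMelUs,
    nSample, nEncode, nDecode, nBatchd, nPrompt }
}

// Example: (tEndUs - tStartUs) / 1000 gives the total elapsed time in ms,
// mirroring the "total time" line in whisper_print_timings.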
cpp/rn-whisper.h (2 additions, 0 deletions)
@@ -9,6 +9,8 @@

namespace rnwhisper {

std::string bench(whisper_context * ctx, int n_threads);

struct vad_params {
bool use_vad = false;
float vad_thold = 0.6f;
cpp/whisper.cpp (42 additions, 17 deletions)
@@ -4190,28 +4190,53 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
return ctx->vocab.token_transcribe;
}

struct whisper_timings whisper_get_timings(struct whisper_context * ctx) {
    if (ctx->state == nullptr) {
        // No state yet: only the load/start times are meaningful
        return {
            .load_us     = ctx->t_load_us,
            .t_start_us  = ctx->t_start_us,
            .fail_p      = 0,
            .fail_h      = 0,
            .t_mel_us    = 0,
            .t_sample_us = 0,
            .t_encode_us = 0,
            .t_decode_us = 0,
            .t_batchd_us = 0,
            .t_prompt_us = 0,
            .n_sample    = 0,
            .n_encode    = 0,
            .n_decode    = 0,
            .n_batchd    = 0,
            .n_prompt    = 0,
        };
    }

    return {
        .load_us     = ctx->t_load_us,
        .t_start_us  = ctx->t_start_us,
        .fail_p      = ctx->state->n_fail_p,
        .fail_h      = ctx->state->n_fail_h,
        .t_mel_us    = ctx->state->t_mel_us,
        .t_sample_us = ctx->state->t_sample_us,
        .t_encode_us = ctx->state->t_encode_us,
        .t_decode_us = ctx->state->t_decode_us,
        .t_batchd_us = ctx->state->t_batchd_us,
        .t_prompt_us = ctx->state->t_prompt_us,
        // clamp run counts to >= 1 so per-run averages never divide by zero
        .n_sample    = std::max(1, ctx->state->n_sample),
        .n_encode    = std::max(1, ctx->state->n_encode),
        .n_decode    = std::max(1, ctx->state->n_decode),
        .n_batchd    = std::max(1, ctx->state->n_batchd),
        .n_prompt    = std::max(1, ctx->state->n_prompt),
    };
}

void whisper_print_timings(struct whisper_context * ctx) {
    const int64_t t_end_us = wsp_ggml_time_us();
    const struct whisper_timings timings = whisper_get_timings(ctx);

    WHISPER_LOG_INFO("\n");
    WHISPER_LOG_INFO("%s:     load time = %8.2f ms\n", __func__, timings.load_us / 1000.0f);
    WHISPER_LOG_INFO("%s:     fallbacks = %3d p / %3d h\n", __func__, timings.fail_p, timings.fail_h);
    WHISPER_LOG_INFO("%s:      mel time = %8.2f ms\n", __func__, timings.t_mel_us / 1000.0f);
    WHISPER_LOG_INFO("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.t_sample_us, timings.n_sample, 1e-3f * timings.t_sample_us / timings.n_sample);
    WHISPER_LOG_INFO("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.t_encode_us, timings.n_encode, 1e-3f * timings.t_encode_us / timings.n_encode);
    WHISPER_LOG_INFO("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.t_decode_us, timings.n_decode, 1e-3f * timings.t_decode_us / timings.n_decode);
    WHISPER_LOG_INFO("%s:   batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.t_batchd_us, timings.n_batchd, 1e-3f * timings.t_batchd_us / timings.n_batchd);
    WHISPER_LOG_INFO("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings.t_prompt_us, timings.n_prompt, 1e-3f * timings.t_prompt_us / timings.n_prompt);
    WHISPER_LOG_INFO("%s:    total time = %8.2f ms\n", __func__, (t_end_us - timings.t_start_us) / 1000.0f);
}

void whisper_reset_timings(struct whisper_context * ctx) {
cpp/whisper.h (14 additions, 0 deletions)
@@ -424,6 +424,20 @@ extern "C" {
WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);

// Performance information from the default state.
struct whisper_timings {
    int64_t load_us;
    int64_t t_start_us;
    int32_t fail_p;
    int32_t fail_h;
    int64_t t_mel_us;
    int64_t t_sample_us;
    int64_t t_encode_us;
    int64_t t_decode_us;
    int64_t t_batchd_us;
    int64_t t_prompt_us;
    int32_t n_sample;
    int32_t n_encode;
    int32_t n_decode;
    int32_t n_batchd;
    int32_t n_prompt;
};

WHISPER_API struct whisper_timings whisper_get_timings(struct whisper_context * ctx);
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);

example/src/Bench.tsx (59 additions, 31 deletions)
@@ -14,30 +14,31 @@ import contextOpts from './context-opts'

const baseURL = 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/'
const modelList = [
  // TODO: Add coreml model download
  { name: 'tiny' },
  { name: 'tiny-q5_1' },
  { name: 'tiny-q8_0' },
  { name: 'base' },
  { name: 'base-q5_1' },
  { name: 'base-q8_0' },
  { name: 'small' },
  { name: 'small-q5_1' },
  { name: 'small-q8_0' },
  { name: 'medium' },
  { name: 'medium-q5_0' },
  { name: 'medium-q8_0' },
  { name: 'large-v1' },
  { name: 'large-v1-q5_0' },
  { name: 'large-v1-q8_0' },
  { name: 'large-v2' },
  { name: 'large-v2-q5_0' },
  { name: 'large-v2-q8_0' },
  { name: 'large-v3' },
  { name: 'large-v3-q5_0' },
  { name: 'large-v3-q8_0' },
  { name: 'large-v3-turbo' },
  { name: 'large-v3-turbo-q5_0' },
  { name: 'large-v3-turbo-q8_0' },
] as const
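Every entry maps to a ggml model file hosted on Hugging Face; the downloader below builds its URL from the model name. The modelURL helper is illustrative only, but the URL shape is exactly what the RNFS.downloadFile call uses:

// URL shape used by the RNFS.downloadFile call further down
const modelURL = (name: string) => `${baseURL}ggml-${name}.bin?download=true`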

const modelNameMap = modelList.reduce((acc, model) => {
@@ -105,6 +106,7 @@ const styles = StyleSheet.create({
top: 0,
bottom: 0,
opacity: 0.5,
width: '0%',
},
logContainer: {
backgroundColor: 'lightgray',
@@ -128,7 +130,14 @@ const Model = (props: {
onDownloadStarted: (modelName: string) => void
onDownloaded: (modelName: string) => void
}) => {
  const {
    model,
    state,
    downloadMap,
    setDownloadMap,
    onDownloadStarted,
    onDownloaded,
  } = props

const downloadRef = useRef<number | null>(null)
const [progress, setProgress] = useState(0)
@@ -160,7 +169,6 @@
onDownloaded(model.name)
return
}
const { jobId, promise } = RNFS.downloadFile({
fromUrl: `${baseURL}ggml-${model.name}.bin?download=true`,
toFile: `${fileDir}/ggml-${model.name}.bin`,
@@ -213,16 +221,13 @@
}

export default function Bench() {
  const [logs, setLogs] = useState<string[]>([])
  const [downloadMap, setDownloadMap] =
    useState<Record<string, boolean>>(modelNameMap)
  const [modelState, setModelState] = useState<'select' | 'download'>('select')

  const downloadedModelsRef = useRef<string[]>([])

  const log = useCallback((...messages: any[]) => {
    setLogs((prev) => [...prev, messages.join(' ')])
  }, [])
@@ -292,7 +297,32 @@
} ${downloadCount} models`}
</Text>
</Pressable>
<Pressable
  style={styles.button}
  onPress={async () => {
    // Chain the models through a promise reduce so they are
    // benchmarked strictly one at a time: each step awaits the
    // previous one before creating the next context.
    await Object.entries(downloadMap).reduce(
      async (acc, [modelName, downloadNeeded]) => {
        await acc
        if (!downloadNeeded) return
        const filePath = `${fileDir}/ggml-${modelName}.bin`
        if (!(await RNFS.exists(filePath))) {
          log(`${modelName} not found, skipping`)
          return
        }
        const ctx = await initWhisper({
          filePath,
          useCoreMLIos: false,
          useGpu: Platform.OS === 'ios',
          useFlashAttn: Platform.OS === 'ios',
        })
        try {
          const result = await ctx.bench(-1)
          log(result)
        } finally {
          // release the native context even if bench throws
          await ctx.release()
        }
      },
      Promise.resolve(),
    )
  }}
>
<Text style={styles.buttonText}>Run benchmark</Text>
</Pressable>
<View style={styles.logContainer}>
Expand All @@ -303,9 +333,7 @@ export default function Bench() {
))}
</View>
<View style={styles.buttonContainer}>
<Pressable style={styles.button} onPress={() => setLogs([])}>
<Text style={styles.buttonText}>Clear Logs</Text>
</Pressable>
<Pressable
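The Run benchmark handler above boils down to one sequence per selected model: initialize a context, call bench, release the context. A standalone sketch of that sequence; the runBenchOnce wrapper is illustrative, while filePath, the initWhisper options, and bench(-1) are taken directly from the component:

const runBenchOnce = async (filePath: string): Promise<string> => {
  const ctx = await initWhisper({
    filePath,
    useCoreMLIos: false,
    useGpu: Platform.OS === 'ios',
    useFlashAttn: Platform.OS === 'ios',
  })
  try {
    // -1 asks the native side to pick its default thread count
    return await ctx.bench(-1)
  } finally {
    // always free the native context, even if bench throws
    await ctx.release()
  }
}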
ios/RNWhisper.mm (19 additions, 0 deletions)
@@ -246,6 +246,25 @@ - (NSArray *)supportedEvents {
resolve(nil);
}

RCT_REMAP_METHOD(bench,
withContextId:(int)contextId
withMaxThreads:(int)maxThreads
withResolver:(RCTPromiseResolveBlock)resolve
withRejecter:(RCTPromiseRejectBlock)reject)
{
RNWhisperContext *context = contexts[[NSNumber numberWithInt:contextId]];
if (context == nil) {
reject(@"whisper_error", @"Context not found", nil);
return;
}
if ([context isTranscribing]) {
reject(@"whisper_error", @"The context is transcribing", nil);
return;
}
NSString *result = [context bench:maxThreads];
resolve(result);
}

RCT_REMAP_METHOD(releaseContext,
withContextId:(int)contextId
withResolver:(RCTPromiseResolveBlock)resolve
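On the JS side, RCT_REMAP_METHOD exposes the method above as a promise-returning bench(contextId, maxThreads) on the RNWhisper native module, which the ctx.bench call in Bench.tsx presumably goes through. A hedged sketch of the raw binding; the benchContext wrapper is illustrative:

import { NativeModules } from 'react-native'

const { RNWhisper } = NativeModules

// Resolves with the timing string produced by rnwhisper::bench,
// or rejects with code 'whisper_error' (context missing or busy).
const benchContext = (contextId: number, maxThreads: number): Promise<string> =>
  RNWhisper.bench(contextId, maxThreads)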
ios/RNWhisperContext.h (1 addition, 0 deletions)
@@ -63,6 +63,7 @@ typedef struct {
- (bool)isTranscribing;
- (bool)isStoppedByAction;
- (NSMutableDictionary *)getTextSegments;
- (NSString *)bench:(int)maxThreads;
- (void)invalidate;

@end
ios/RNWhisperContext.mm (10 additions, 0 deletions)
@@ -569,6 +569,16 @@ - (NSMutableDictionary *)getTextSegments {
return result;
}

- (NSString *)bench:(int)maxThreads {
    const int max_threads = (int) [[NSProcessInfo processInfo] processorCount];
    // Use 2 threads by default on 4-core devices, 4 threads on more cores
    const int default_n_threads = max_threads == 4 ? 2 : MIN(4, max_threads);
    // A non-positive maxThreads falls back to the default heuristic
    const int n_threads = maxThreads > 0 ? maxThreads : default_n_threads;
    return [NSString stringWithUTF8String:rnwhisper::bench(self->ctx, n_threads).c_str()];
}

- (void)invalidate {
[self stopCurrentTranscribe];
whisper_free(self->ctx);