From 492a6dc9f9629287924f5ecef1f7f047222aa229 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Thu, 7 Nov 2024 10:24:14 +0800 Subject: [PATCH] feat(example): download models --- example/src/Bench.tsx | 265 +++++++++++++++++++++++++++++++++++------- 1 file changed, 224 insertions(+), 41 deletions(-) diff --git a/example/src/Bench.tsx b/example/src/Bench.tsx index 3a69fa5..066d835 100644 --- a/example/src/Bench.tsx +++ b/example/src/Bench.tsx @@ -1,4 +1,4 @@ -import React, { useCallback, useRef, useState } from 'react' +import React, { useCallback, useEffect, useRef, useState } from 'react' import { StyleSheet, ScrollView, @@ -17,27 +17,27 @@ const modelList = [ { name: 'tiny', coreml: true }, { name: 'tiny-q5_1' }, { name: 'tiny-q8_0' }, - { name: 'base', coreml: true }, - { name: 'base-q5_1' }, - { name: 'base-q8_0' }, - { name: 'small', coreml: true }, - { name: 'small-q5_1' }, - { name: 'small-q8_0' }, - { name: 'medium', coreml: true }, - { name: 'medium-q5_0' }, - { name: 'medium-q8_0' }, + // { name: 'base', coreml: true }, + // { name: 'base-q5_1' }, + // { name: 'base-q8_0' }, + // { name: 'small', coreml: true }, + // { name: 'small-q5_1' }, + // { name: 'small-q8_0' }, + // { name: 'medium', coreml: true }, + // { name: 'medium-q5_0' }, + // { name: 'medium-q8_0' }, // { name: 'large-v1', coreml: true }, // { name: 'large-v1-q5_0', }, // { name: 'large-v1-q8_0', }, - { name: 'large-v2', coreml: true }, - { name: 'large-v2-q5_0' }, - { name: 'large-v2-q8_0' }, - { name: 'large-v3', coreml: true }, - { name: 'large-v3-q5_0' }, - { name: 'large-v3-q8_0' }, - { name: 'large-v3-turbo', coreml: true }, - { name: 'large-v3-turbo-q5_0' }, - { name: 'large-v3-turbo-q8_0' }, + // { name: 'large-v2', coreml: true }, + // { name: 'large-v2-q5_0' }, + // { name: 'large-v2-q8_0' }, + // { name: 'large-v3', coreml: true }, + // { name: 'large-v3-q5_0' }, + // { name: 'large-v3-q8_0' }, + // { name: 'large-v3-turbo', coreml: true }, + // { name: 'large-v3-turbo-q5_0' }, + // { name: 'large-v3-turbo-q8_0' }, ] as const const modelNameMap = modelList.reduce((acc, model) => { @@ -70,7 +70,6 @@ const styles = StyleSheet.create({ modelItem: { backgroundColor: '#333', borderRadius: 5, - padding: 5, margin: 4, flexDirection: 'row', alignItems: 'center', @@ -79,6 +78,7 @@ const styles = StyleSheet.create({ backgroundColor: '#aaa', }, modelItemText: { + margin: 6, color: '#ccc', fontSize: 12, fontWeight: 'bold', @@ -98,47 +98,230 @@ const styles = StyleSheet.create({ fontSize: 12, fontWeight: 'bold', }, + progressBar: { + backgroundColor: '#3388ff', + position: 'absolute', + left: 0, + top: 0, + bottom: 0, + opacity: 0.5, + }, + logContainer: { + backgroundColor: 'lightgray', + padding: 8, + width: '95%', + borderRadius: 8, + marginVertical: 8, + }, + logText: { fontSize: 12, color: '#333' }, + buttonContainer: { + flexDirection: 'row', + justifyContent: 'center', + }, }) +const Model = (props: { + model: (typeof modelList)[number] + state: 'select' | 'download' + downloadMap: Record + setDownloadMap: (downloadMap: Record) => void + onDownloadStarted: (modelName: string) => void + onDownloaded: (modelName: string) => void +}) => { + const { model, state, downloadMap, setDownloadMap, onDownloadStarted, onDownloaded } = props + + const downloadRef = useRef(null) + const [progress, setProgress] = useState(0) + + const downloadNeeded = downloadMap[model.name] + + const cancelDownload = async () => { + if (downloadRef.current) { + RNFS.stopDownload(downloadRef.current) + downloadRef.current = null + setProgress(0) + } + } + + useEffect(() => { + if (state !== 'select') return + RNFS.exists(`${fileDir}/ggml-${model.name}.bin`).then((exists) => { + if (exists) setProgress(1) + else setProgress(0) + }) + }, [model.name, state]) + + useEffect(() => { + if (state === 'download') { + const download = async () => { + if (!downloadNeeded) return cancelDownload() + if (await RNFS.exists(`${fileDir}/ggml-${model.name}.bin`)) { + setProgress(1) + onDownloaded(model.name) + return + } + console.log('[Model] download', `${baseURL}${model.name}.bin?download=true`) + const { jobId, promise } = RNFS.downloadFile({ + fromUrl: `${baseURL}ggml-${model.name}.bin?download=true`, + toFile: `${fileDir}/ggml-${model.name}.bin`, + begin: () => { + setProgress(0) + onDownloadStarted(model.name) + }, + progress: (res) => { + setProgress(res.bytesWritten / res.contentLength) + }, + }) + downloadRef.current = jobId + promise.then(() => { + setProgress(1) + onDownloaded(model.name) + }) + } + download() + } else { + cancelDownload() + } + }, [state, downloadNeeded, model.name, onDownloadStarted, onDownloaded]) + + return ( + { + if (downloadRef.current) { + cancelDownload() + return + } + if (state === 'download') return + setDownloadMap({ + ...downloadMap, + [model.name]: !downloadMap[model.name], + }) + }} + > + {model.name} + + {downloadNeeded && ( + + )} + + ) +} + export default function Bench() { const whisperContextRef = useRef(null) const whisperContext = whisperContextRef.current - const [logs, setLogs] = useState([]) - const [downloadMap, setDownloadMap] = useState>(modelNameMap) - const downloadCount = Object.keys(downloadMap).filter((key) => downloadMap[key]).length + const [logs, setLogs] = useState([]) + const [downloadMap, setDownloadMap] = + useState>(modelNameMap) + const [modelState, setModelState] = useState<'select' | 'download'>('select') + + const downloadedModelsRef = useRef([]) + + + const log = useCallback((...messages: any[]) => { + setLogs((prev) => [...prev, messages.join(' ')]) + }, []) + + useEffect(() => { + if ( + downloadedModelsRef.current.length === + Object.values(downloadMap).filter(Boolean).length + ) { + downloadedModelsRef.current = [] + setModelState('select') + log('All models downloaded') + } + }, [log, logs, downloadMap]) + + const handleDownloadStarted = useCallback( + (modelName: string) => { + log(`Downloading ${modelName}`) + }, + [log], + ) + + const handleDownloaded = useCallback( + (modelName: string) => { + downloadedModelsRef.current = [...downloadedModelsRef.current, modelName] + log(`Downloaded ${modelName}`) + }, + [log], + ) + + const downloadCount = Object.keys(downloadMap).filter( + (key) => downloadMap[key], + ).length return ( - + Download List {modelList.map((model) => ( - { - setDownloadMap({ - ...downloadMap, - [model.name]: !downloadMap[model.name], - }) - }} - > - {model.name} - + state={modelState} + model={model} + downloadMap={downloadMap} + setDownloadMap={setDownloadMap} + onDownloadStarted={handleDownloadStarted} + onDownloaded={handleDownloaded} + /> ))} { + if (modelState === 'select') { + downloadedModelsRef.current = [] + setModelState('download') + } else { + setModelState('select') + } }} > - {`Download ${downloadCount} models`} + + {`${ + modelState === 'select' ? 'Download' : 'Cancel' + } ${downloadCount} models`} + - { - }} - > + {}}> Run benchmark + + {logs.map((msg, index) => ( + + {msg} + + ))} + + + { + setLogs([]) + }}> + Clear Logs + + { + setModelState('select') + RNFS.readDir(fileDir).then((files) => { + files.forEach((file) => { + if (file.name.startsWith('ggml-')) RNFS.unlink(file.path) + }) + }) + }} + > + Clear Downloaded Models + + ) }