Skip to content

Commit

Permalink
feat(example): download models
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Nov 7, 2024
1 parent a210e88 commit 492a6dc
Showing 1 changed file with 224 additions and 41 deletions.
265 changes: 224 additions & 41 deletions example/src/Bench.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React, { useCallback, useRef, useState } from 'react'
import React, { useCallback, useEffect, useRef, useState } from 'react'
import {
StyleSheet,
ScrollView,
Expand All @@ -17,27 +17,27 @@ const modelList = [
{ name: 'tiny', coreml: true },
{ name: 'tiny-q5_1' },
{ name: 'tiny-q8_0' },
{ name: 'base', coreml: true },
{ name: 'base-q5_1' },
{ name: 'base-q8_0' },
{ name: 'small', coreml: true },
{ name: 'small-q5_1' },
{ name: 'small-q8_0' },
{ name: 'medium', coreml: true },
{ name: 'medium-q5_0' },
{ name: 'medium-q8_0' },
// { name: 'base', coreml: true },
// { name: 'base-q5_1' },
// { name: 'base-q8_0' },
// { name: 'small', coreml: true },
// { name: 'small-q5_1' },
// { name: 'small-q8_0' },
// { name: 'medium', coreml: true },
// { name: 'medium-q5_0' },
// { name: 'medium-q8_0' },
// { name: 'large-v1', coreml: true },
// { name: 'large-v1-q5_0', },
// { name: 'large-v1-q8_0', },
{ name: 'large-v2', coreml: true },
{ name: 'large-v2-q5_0' },
{ name: 'large-v2-q8_0' },
{ name: 'large-v3', coreml: true },
{ name: 'large-v3-q5_0' },
{ name: 'large-v3-q8_0' },
{ name: 'large-v3-turbo', coreml: true },
{ name: 'large-v3-turbo-q5_0' },
{ name: 'large-v3-turbo-q8_0' },
// { name: 'large-v2', coreml: true },
// { name: 'large-v2-q5_0' },
// { name: 'large-v2-q8_0' },
// { name: 'large-v3', coreml: true },
// { name: 'large-v3-q5_0' },
// { name: 'large-v3-q8_0' },
// { name: 'large-v3-turbo', coreml: true },
// { name: 'large-v3-turbo-q5_0' },
// { name: 'large-v3-turbo-q8_0' },
] as const

const modelNameMap = modelList.reduce((acc, model) => {
Expand Down Expand Up @@ -70,7 +70,6 @@ const styles = StyleSheet.create({
modelItem: {
backgroundColor: '#333',
borderRadius: 5,
padding: 5,
margin: 4,
flexDirection: 'row',
alignItems: 'center',
Expand All @@ -79,6 +78,7 @@ const styles = StyleSheet.create({
backgroundColor: '#aaa',
},
modelItemText: {
margin: 6,
color: '#ccc',
fontSize: 12,
fontWeight: 'bold',
Expand All @@ -98,47 +98,230 @@ const styles = StyleSheet.create({
fontSize: 12,
fontWeight: 'bold',
},
progressBar: {
backgroundColor: '#3388ff',
position: 'absolute',
left: 0,
top: 0,
bottom: 0,
opacity: 0.5,
},
logContainer: {
backgroundColor: 'lightgray',
padding: 8,
width: '95%',
borderRadius: 8,
marginVertical: 8,
},
logText: { fontSize: 12, color: '#333' },
buttonContainer: {
flexDirection: 'row',
justifyContent: 'center',
},
})

const Model = (props: {
model: (typeof modelList)[number]
state: 'select' | 'download'
downloadMap: Record<string, boolean>
setDownloadMap: (downloadMap: Record<string, boolean>) => void
onDownloadStarted: (modelName: string) => void
onDownloaded: (modelName: string) => void
}) => {
const { model, state, downloadMap, setDownloadMap, onDownloadStarted, onDownloaded } = props

const downloadRef = useRef<number | null>(null)
const [progress, setProgress] = useState(0)

const downloadNeeded = downloadMap[model.name]

const cancelDownload = async () => {
if (downloadRef.current) {
RNFS.stopDownload(downloadRef.current)
downloadRef.current = null
setProgress(0)
}
}

useEffect(() => {
if (state !== 'select') return
RNFS.exists(`${fileDir}/ggml-${model.name}.bin`).then((exists) => {
if (exists) setProgress(1)
else setProgress(0)
})
}, [model.name, state])

useEffect(() => {
if (state === 'download') {
const download = async () => {
if (!downloadNeeded) return cancelDownload()
if (await RNFS.exists(`${fileDir}/ggml-${model.name}.bin`)) {
setProgress(1)
onDownloaded(model.name)
return
}
console.log('[Model] download', `${baseURL}${model.name}.bin?download=true`)
const { jobId, promise } = RNFS.downloadFile({
fromUrl: `${baseURL}ggml-${model.name}.bin?download=true`,
toFile: `${fileDir}/ggml-${model.name}.bin`,
begin: () => {
setProgress(0)
onDownloadStarted(model.name)
},
progress: (res) => {
setProgress(res.bytesWritten / res.contentLength)
},
})
downloadRef.current = jobId
promise.then(() => {
setProgress(1)
onDownloaded(model.name)
})
}
download()
} else {
cancelDownload()
}
}, [state, downloadNeeded, model.name, onDownloadStarted, onDownloaded])

return (
<Pressable
key={model.name}
style={[
styles.modelItem,
!downloadMap[model.name] && styles.modelItemUnselected,
]}
onPress={() => {
if (downloadRef.current) {
cancelDownload()
return
}
if (state === 'download') return
setDownloadMap({
...downloadMap,
[model.name]: !downloadMap[model.name],
})
}}
>
<Text style={styles.modelItemText}>{model.name}</Text>

{downloadNeeded && (
<View style={[styles.progressBar, { width: `${progress * 100}%` }]} />
)}
</Pressable>
)
}

export default function Bench() {
const whisperContextRef = useRef<WhisperContext | null>(null)
const whisperContext = whisperContextRef.current
const [logs, setLogs] = useState([])
const [downloadMap, setDownloadMap] = useState<Record<string, boolean>>(modelNameMap)
const downloadCount = Object.keys(downloadMap).filter((key) => downloadMap[key]).length
const [logs, setLogs] = useState<string[]>([])
const [downloadMap, setDownloadMap] =
useState<Record<string, boolean>>(modelNameMap)
const [modelState, setModelState] = useState<'select' | 'download'>('select')

const downloadedModelsRef = useRef<string[]>([])


const log = useCallback((...messages: any[]) => {
setLogs((prev) => [...prev, messages.join(' ')])
}, [])

useEffect(() => {
if (
downloadedModelsRef.current.length ===
Object.values(downloadMap).filter(Boolean).length
) {
downloadedModelsRef.current = []
setModelState('select')
log('All models downloaded')
}
}, [log, logs, downloadMap])

const handleDownloadStarted = useCallback(
(modelName: string) => {
log(`Downloading ${modelName}`)
},
[log],
)

const handleDownloaded = useCallback(
(modelName: string) => {
downloadedModelsRef.current = [...downloadedModelsRef.current, modelName]
log(`Downloaded ${modelName}`)
},
[log],
)

const downloadCount = Object.keys(downloadMap).filter(
(key) => downloadMap[key],
).length
return (
<ScrollView style={styles.container} contentContainerStyle={styles.contentContainer}>
<ScrollView
style={styles.container}
contentContainerStyle={styles.contentContainer}
>
<Text style={styles.title}>Download List</Text>
<View style={styles.modelList}>
{modelList.map((model) => (
<Pressable
<Model
key={model.name}
style={[styles.modelItem, !downloadMap[model.name] && styles.modelItemUnselected]}
onPress={() => {
setDownloadMap({
...downloadMap,
[model.name]: !downloadMap[model.name],
})
}}
>
<Text style={styles.modelItemText}>{model.name}</Text>
</Pressable>
state={modelState}
model={model}
downloadMap={downloadMap}
setDownloadMap={setDownloadMap}
onDownloadStarted={handleDownloadStarted}
onDownloaded={handleDownloaded}
/>
))}
</View>
<Pressable
style={styles.button}
onPress={() => {
if (modelState === 'select') {
downloadedModelsRef.current = []
setModelState('download')
} else {
setModelState('select')
}
}}
>
<Text style={styles.buttonText}>{`Download ${downloadCount} models`}</Text>
<Text style={styles.buttonText}>
{`${
modelState === 'select' ? 'Download' : 'Cancel'
} ${downloadCount} models`}
</Text>
</Pressable>
<Pressable
style={styles.button}
onPress={() => {
}}
>
<Pressable style={styles.button} onPress={() => {}}>
<Text style={styles.buttonText}>Run benchmark</Text>
</Pressable>
<View style={styles.logContainer}>
{logs.map((msg, index) => (
<Text key={index} style={styles.logText}>
{msg}
</Text>
))}
</View>
<View style={styles.buttonContainer}>
<Pressable style={styles.button} onPress={() => {
setLogs([])
}}>
<Text style={styles.buttonText}>Clear Logs</Text>
</Pressable>
<Pressable
style={styles.button}
onPress={() => {
setModelState('select')
RNFS.readDir(fileDir).then((files) => {
files.forEach((file) => {
if (file.name.startsWith('ggml-')) RNFS.unlink(file.path)
})
})
}}
>
<Text style={styles.buttonText}>Clear Downloaded Models</Text>
</Pressable>
</View>
</ScrollView>
)
}

0 comments on commit 492a6dc

Please sign in to comment.