feat: added loading model callback, updated cui-llama.rn
Vali-98 committed Sep 26, 2024
1 parent dd8a749 commit aaa8ee6
Showing 5 changed files with 265 additions and 32 deletions.
87 changes: 66 additions & 21 deletions app/components/Endpoint/Local.tsx
@@ -12,17 +12,23 @@ import {
} from 'react-native'
import { Dropdown } from 'react-native-element-dropdown'
import { useMMKVBoolean, useMMKVObject, useMMKVString } from 'react-native-mmkv'
import * as Progress from 'react-native-progress'

import { SliderItem } from '..'

const Local = () => {
const { loadModel, unloadModel, modelName } = Llama.useLlama((state) => ({
loadModel: state.load,
unloadModel: state.unload,
modelName: state.modelname,
}))
const { loadModel, unloadModel, modelName, loadProgress, setLoadProgress } = Llama.useLlama(
(state) => ({
loadModel: state.load,
unloadModel: state.unload,
modelName: state.modelname,
loadProgress: state.loadProgress,
setLoadProgress: state.setLoadProgress,
})
)

const [modelLoading, setModelLoading] = useState(false)
const [modelImporting, setModelImporting] = useState(false)
const [modelList, setModelList] = useState<string[]>([])
const dropdownValues = modelList.map((item) => {
return { name: item }
@@ -45,18 +51,11 @@ const Local = () => {

const handleLoad = async () => {
setModelLoading(true)
setLoadProgress(0)
await loadModel(currentModel ?? '', preset)
setModelLoading(false)
getModels()
}
/*
const handleLoadExternal = async () => {
setModelLoading(true)
await Llama.loadModel('', preset, false).then(() => {
setLoadedModel(Llama.getModelname())
})
setModelLoading(false)
}*/

const handleDelete = async () => {
if (!(await Llama.modelExists(currentModel ?? ''))) {
@@ -94,10 +93,10 @@ const Local = () => {
}*/

const handleImport = async () => {
setModelLoading(true)
setModelImporting(true)
await Llama.importModel()
await getModels()
setModelLoading(false)
setModelImporting(false)
}

const disableLoad = modelList.length === 0 || modelName !== undefined
@@ -137,9 +136,52 @@ const Local = () => {
/>
</View>

{modelLoading ? (
<ActivityIndicator size="large" color={Style.getColor('primary-text1')} />
) : (
{!modelLoading && modelImporting && (
<View style={{ flexDirection: 'row', alignItems: 'center' }}>
<Progress.Bar
style={{ marginVertical: 16, flex: 5 }}
indeterminate
indeterminateAnimationDuration={2000}
color={Style.getColor('primary-brand')}
borderColor={Style.getColor('primary-surface3')}
height={12}
borderRadius={12}
width={null}
/>
<Text
style={{
flex: 2,
color: Style.getColor('primary-text1'),
textAlign: 'center',
}}>
Importing...
</Text>
</View>
)}

{modelLoading && !modelImporting && (
<View style={{ flexDirection: 'row', alignItems: 'center' }}>
<Progress.Bar
style={{ marginVertical: 16, flex: 5 }}
progress={loadProgress / 100}
color={Style.getColor('primary-brand')}
borderColor={Style.getColor('primary-surface3')}
height={12}
borderRadius={12}
width={null}
/>
<Text
style={{
flex: 1,
color: Style.getColor('primary-text1'),
textAlign: 'center',
}}>
{loadProgress}%
</Text>
</View>
)}

{!modelLoading && !modelImporting && (
<View style={{ flexDirection: 'row', marginTop: 8 }}>
<TouchableOpacity
disabled={disableLoad}
@@ -225,9 +267,10 @@ const Local = () => {
body={preset}
setValue={setPreset}
varname="context_length"
min={512}
min={1024}
max={32768}
step={512}
step={1024}
disabled={modelImporting || modelLoading}
/>
<SliderItem
name="Threads"
@@ -237,6 +280,7 @@ const Local = () => {
min={1}
max={8}
step={1}
disabled={modelImporting || modelLoading}
/>

<SliderItem
@@ -246,7 +290,8 @@ const Local = () => {
varname="batch"
min={16}
max={512}
step={1}
step={16}
disabled={modelImporting || modelLoading}
/>
{/* Note: llama.rn does not have any Android gpu acceleration */}
{Platform.OS === 'ios' && (
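The three render branches added to Local.tsx are mutually exclusive: importing shows an indeterminate bar, loading shows a determinate bar driven by the store's loadProgress, and the idle state falls through to the load/delete controls. A condensed sketch of that logic (component and prop names here are illustrative; the colors and exact layout from the diff are omitted):

```tsx
import React from 'react'
import { Text, View } from 'react-native'
import * as Progress from 'react-native-progress'

type LoadStateBarProps = {
    modelImporting: boolean
    modelLoading: boolean
    loadProgress: number // whole percentage 0-100, as written by setLoadProgress
}

// Illustrative stand-in for the conditional blocks in Local.tsx above
const LoadStateBar = ({ modelImporting, modelLoading, loadProgress }: LoadStateBarProps) => {
    if (modelImporting && !modelLoading) {
        // Import duration is unknown, so the bar animates indeterminately
        return (
            <View style={{ flexDirection: 'row', alignItems: 'center' }}>
                <Progress.Bar indeterminate width={null} height={12} borderRadius={12} style={{ flex: 5 }} />
                <Text style={{ flex: 2, textAlign: 'center' }}>Importing...</Text>
            </View>
        )
    }
    if (modelLoading && !modelImporting) {
        // Model loading reports progress, so the bar tracks loadProgress directly
        return (
            <View style={{ flexDirection: 'row', alignItems: 'center' }}>
                <Progress.Bar progress={loadProgress / 100} width={null} height={12} borderRadius={12} style={{ flex: 5 }} />
                <Text style={{ flex: 1, textAlign: 'center' }}>{loadProgress}%</Text>
            </View>
        )
    }
    return null // idle: the Load / Delete buttons render instead
}

export default LoadStateBar
```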
21 changes: 19 additions & 2 deletions app/components/SliderItem.tsx
@@ -14,6 +14,7 @@ type SliderItemProps = {
step?: number
precision?: number
showInput?: boolean
disabled?: boolean
}

const SliderItem: React.FC<SliderItemProps> = ({
@@ -27,6 +28,7 @@ const SliderItem: React.FC<SliderItemProps> = ({
precision = 0,
onChange = undefined,
showInput = true,
disabled = false,
}) => {
const clamp = (val: number) => Math.min(Math.max(parseFloat(val?.toFixed(2) ?? 0), min), max)
const [textValue, setTextValue] = useState(body[varname]?.toFixed(precision))
@@ -57,9 +59,10 @@ const SliderItem: React.FC<SliderItemProps> = ({

return (
<View style={{ alignItems: `center` }}>
<Text style={styles.itemName}>{name}</Text>
<Text style={disabled ? styles.itemNameDisabled : styles.itemName}>{name}</Text>
<View style={styles.sliderContainer}>
<Slider
disabled={disabled}
style={styles.slider}
step={step}
minimumValue={min}
@@ -72,7 +75,8 @@ const SliderItem: React.FC<SliderItemProps> = ({
/>
{showInput && (
<TextInput
style={styles.textBox}
editable={!disabled}
style={disabled ? styles.textBoxDisabled : styles.textBox}
value={textValue}
onChangeText={setTextValue}
onEndEditing={handleTextInputChange}
@@ -92,6 +96,10 @@ const styles = StyleSheet.create({
color: Style.getColor('primary-text1'),
},

itemNameDisabled: {
color: Style.getColor('primary-text3'),
},

sliderContainer: {
flexDirection: `row`,
},
@@ -109,4 +117,13 @@ const styles = StyleSheet.create({
flex: 1.5,
textAlign: `center`,
},

textBoxDisabled: {
borderColor: Style.getColor('primary-surface4'),
color: Style.getColor('primary-text3'),
borderWidth: 1,
borderRadius: 12,
flex: 1.5,
textAlign: `center`,
},
})
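A call site opts into the new prop the same way the sliders in Local.tsx above do; a minimal usage sketch (preset, setPreset, and the two loading flags are assumed to come from the surrounding component):

```tsx
<SliderItem
    name="Context Length"
    body={preset}
    setValue={setPreset}
    varname="context_length"
    min={1024}
    max={32768}
    step={1024}
    // Greys out the label, locks the slider, and makes the text box read-only
    disabled={modelImporting || modelLoading}
/>
```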
18 changes: 15 additions & 3 deletions app/constants/LlamaLocal.ts
@@ -28,7 +28,9 @@ type CompletionOutput = {
type LlamaState = {
context: LlamaContext | undefined
modelname: string | undefined
loadProgress: number
load: (name: string, preset?: LlamaPreset, usecache?: boolean) => Promise<void>
setLoadProgress: (progress: number) => void
unload: () => Promise<void>
saveKV: () => Promise<void>
loadKV: () => Promise<void>
@@ -64,6 +66,7 @@ export namespace Llama {
export const useLlama = create<LlamaState>()((set, get) => ({
context: undefined,
modelname: undefined,
loadProgress: 0,
load: async (
name: string,
preset: LlamaPreset = default_preset,
@@ -99,10 +102,16 @@ export namespace Llama {
}

mmkv.set(Global.LocalSessionLoaded, false)
Logger.log(`Loading Model: ${name}`, true)
Logger.log(JSON.stringify(params))
Logger.log(`Loading Model: ${name}`)
Logger.log(
`Starting with parameters: \nContext Length: ${params.n_ctx}\nThreads: ${params.n_threads}\nBatch Size: ${params.n_batch}`
)

const llamaContext = await initLlama(params).catch((error) => {
const progressCallback = (progress: number) => {
if (progress % 5 === 0) get().setLoadProgress(progress)
}

const llamaContext = await initLlama(params, progressCallback).catch((error) => {
Logger.log(`Could Not Load Model: ${error} `, true)
})

@@ -111,6 +120,9 @@ export namespace Llama {
Logger.log('Model Loaded', true)
}
},
setLoadProgress: (progress: number) => {
set((state) => ({ ...state, loadProgress: progress }))
},
unload: async () => {
await get().context?.release()
set((state) => ({ ...state, context: undefined, modelname: undefined }))
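The progress plumbing reduces to a small pattern that can be sketched in isolation. The store and callback below are illustrative stand-ins for Llama.useLlama and the callback handed to initLlama; the multiple-of-5 throttle matches the diff:

```ts
import { create } from 'zustand'

type ProgressState = {
    loadProgress: number
    setLoadProgress: (progress: number) => void
}

// Stand-in for the loadProgress slice added to LlamaState above
const useProgress = create<ProgressState>()((set) => ({
    loadProgress: 0,
    setLoadProgress: (progress) => set({ loadProgress: progress }),
}))

// Passed to the model loader; writing only multiples of 5 keeps store updates
// (and progress-bar re-renders) to at most ~20 per load.
const progressCallback = (progress: number) => {
    if (progress % 5 === 0) useProgress.getState().setLoadProgress(progress)
}

// Components subscribe with a selector, e.g.:
// const loadProgress = useProgress((state) => state.loadProgress)
```

Progress values that are not multiples of 5 are dropped on purpose, so the bar advances in coarse steps instead of re-rendering on every callback tick.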