From d6f6f49cde72fff809487d01204b018a7aaa9e73 Mon Sep 17 00:00:00 2001 From: Rusyaidi Date: Sun, 5 Nov 2023 13:45:17 +0800 Subject: [PATCH] Initial working adventure mode. --- app/CharInfo.js | 22 +++++ app/Settings.js | 25 +++++- app/index.js | 98 ++++++++++++++++++++-- components/ChatMenu/ChatWindow/ChatItem.js | 3 +- constants/global.tsx | 12 ++- lib/Inference.js | 76 ++++++++++++++--- 6 files changed, 211 insertions(+), 25 deletions(-) diff --git a/app/CharInfo.js b/app/CharInfo.js index b1bbe10..01e24fc 100644 --- a/app/CharInfo.js +++ b/app/CharInfo.js @@ -135,6 +135,23 @@ const CharInfo = () => { numberOfLines={8} /> + + Adventure Options + eg. option1 || option2 || option3 + + { + if(characterCard.spec !== undefined && characterCard.spec === 'chara_card_v2') + setCharacterCard({...characterCard, adventure_options: mes, data: {...characterCard.data, adventure_options: mes} }) + else + setCharacterCard({...characterCard, adventure_options: mes }) + }} + value={characterCard?.data?.adventure_options ?? characterCard?.adventure_options} + numberOfLines={8} + /> + ) @@ -199,6 +216,11 @@ const styles = StyleSheet.create({ paddingBottom: 8, }, + boxTextGray:{ + color:Color.Offwhite, + paddingBottom: 8, + }, + input: { color: Color.Text, textAlignVertical: 'top', diff --git a/app/Settings.js b/app/Settings.js index 2862917..cb78c0e 100644 --- a/app/Settings.js +++ b/app/Settings.js @@ -1,14 +1,15 @@ import { SafeAreaView, View, Text, Image, StyleSheet } from 'react-native' import { Global, Color, API, getUserImageDirectory } from '@globals' import React from 'react' -import { useMMKVString } from 'react-native-mmkv' -import { TouchableOpacity } from 'react-native-gesture-handler' +import { useMMKVBoolean, useMMKVString } from 'react-native-mmkv' +import { Switch, TouchableOpacity } from 'react-native-gesture-handler' import { FontAwesome } from '@expo/vector-icons' import { useRouter } from 'expo-router' const Settings = () => { const [userName, setUserName] = useMMKVString(Global.CurrentUser) const [apiType, setAPIType] = useMMKVString(Global.APIType) + const [adventureMode, setAdventureMode] = useMMKVBoolean(Global.AdventureEnabled) const router = useRouter() return ( @@ -48,6 +49,19 @@ const Settings = () => { + + { (apiType === API.KAI) && ( + + + Adventure Mode + + )} + ) @@ -123,4 +137,11 @@ const styles = StyleSheet.create({ borderColor: Color.Offwhite, }, + switchContainer : { + marginTop: 20, + alignItems: 'center', + flexDirection: 'row', + marginHorizontal: 16, + } + }) \ No newline at end of file diff --git a/app/index.js b/app/index.js index 493ff5c..38d80aa 100644 --- a/app/index.js +++ b/app/index.js @@ -1,7 +1,9 @@ import { View, Text, TextInput, SafeAreaView, TouchableOpacity, - StyleSheet + StyleSheet, + ToastAndroid, + ActivityIndicator } from 'react-native' import { useState, useEffect} from 'react' import {ChatWindow }from '@components/ChatMenu/ChatWindow/ChatWindow' @@ -23,12 +25,17 @@ const Home = () => { const [nowGenerating, setNowGenerating] = useMMKVBoolean(Global.NowGenerating) // Instruct const [currentInstruct, setCurrentInstruct] = useMMKVObject(Global.CurrentInstruct) + + // api / adventure + const [apiType, setApiType] = useMMKVString(Global.APIType) + const [adventureMode, setAdventureMode] = useMMKVBoolean(Global.AdventureEnabled) // Local const [messages, setMessages] = useState([]); const [newMessage, setNewMessage] = useState(''); const [targetLength, setTargetLength] = useState(0) const [abortFunction, setAbortFunction] = useState(undefined) + // load character chat upon character change useEffect(() => { if (charName === 'Welcome' || charName === undefined) return @@ -42,6 +49,10 @@ const Home = () => { }, [charName]) + useEffect(() => { + //setAdventureMode(false) + }, [apiType]) + // triggers generation when set true // TODO : Use this to save instead useEffect(() => { @@ -68,7 +79,7 @@ const Home = () => { generateResponse(setAbortFunction, insertGeneratedMessage, messages) } - const handleSend = () => { + const handleSend = (newMessage) => { if (newMessage.trim() !== ''){ const newMessageItem = createChatEntry(userName, true, newMessage) setMessages(messages => [...messages, newMessageItem]) @@ -78,18 +89,32 @@ const Home = () => { setNowGenerating(true) } - const insertGeneratedMessage = (data) => { + const insertGeneratedMessage = (input_data) => { + let data = "" + let adventure_data = "" + if(adventureMode) { + const filtered = input_data.split("```") + data = filtered[0] + if(filtered.length > 1) + adventure_data = filtered[1].split('\n').filter(item => {return item.startsWith(`[`)}).map(text => {return text.split(`] `)[1]}).join('||') + } + else data = input_data + setMessages(messages => { try { const createnew = (messages.length < targetLength) const mescontent = ((createnew ) ? data : messages.at(-1).mes + data) .replaceAll(currentInstruct.input_sequence, ``) .replaceAll(currentInstruct.output_sequence, ``) - const newmessage = (createnew) ? createChatEntry(charName, false, "") : messages.at(-1) - newmessage.mes = mescontent + let newmessage = { + ...((createnew) ? createChatEntry(charName, false, "") : messages.at(-1)), + mes : mescontent, + gen_finished:Date() , + adventure_options : adventure_data + } newmessage.swipes[newmessage.swipe_id] = mescontent - newmessage.gen_finished = Date() newmessage.swipe_info[newmessage.swipe_id].gen_finished = Date() + newmessage.swipe_info[newmessage.swipe_id].adventure_options = adventure_data const finalized_messages = createnew ? [...messages , newmessage] : [...messages.slice(0,-1), newmessage] return finalized_messages } catch (error) { @@ -99,6 +124,16 @@ const Home = () => { }) } + const getAdventureOptions = () => { + if(messages.length === 0|| messages.at(-1).name !== charName || messages.at(-1)?.adventure_options === undefined) return [] + try { + return messages.at(-1)?.adventure_options.split("||") ?? messages.at(-1)?.data?.adventure_options.split("||") + } catch { + ToastAndroid.show(`Something is wrong with Options formatting`, 2000) + return [] + } + } + const abortResponse = () => { console.log(`Aborting Generation`) if(abortFunction !== undefined) @@ -144,6 +179,41 @@ const Home = () => { + {(adventureMode) ? + (nowGenerating ? + + + + : + (messages.at(-1).name === charName && + + {(messages.at(-1)?.adventure_options !== undefined && messages.at(-1)?.adventure_options !== '') ? + getAdventureOptions().map((text, index) => ( + + { + setNewMessage(text) + setTargetLength(messages.length + 1) + handleSend(text) + }} + > + {text} + + )) + : + + { + setTargetLength(messages.length) + setNowGenerating(true) + }} + > + Generate Responses + + + } + )) + : @@ -176,12 +246,13 @@ const Home = () => { : - + handleSend(newMessage)}> } + } } @@ -257,6 +328,19 @@ const styles = StyleSheet.create({ color: Color.Text, marginLeft: 16, }, + + adventureInput: { + marginBottom: 8, + }, + + adventureOptionContainer: { + marginHorizontal: 16, + marginVertical: 2, + padding: 12, + borderRadius: 16, + backgroundColor: Color.DarkContainer, + justifyContent: 'center', + }, }); export default Home; diff --git a/components/ChatMenu/ChatWindow/ChatItem.js b/components/ChatMenu/ChatWindow/ChatItem.js index af56e9e..97cc726 100644 --- a/components/ChatMenu/ChatWindow/ChatItem.js +++ b/components/ChatMenu/ChatWindow/ChatItem.js @@ -192,6 +192,7 @@ const ChatItem = ({ message, id, scroll}) => { newmessages.at(id + 1).send_date = messages.at(id + 1).swipe_info.at(swipeid).send_date newmessages.at(id + 1).gen_started = messages.at(id + 1).swipe_info.at(swipeid).gen_started newmessages.at(id + 1).gen_finished = messages.at(id + 1).swipe_info.at(swipeid).gen_finished + newmessages.at(id + 1).adventure_options = messages.at(id + 1).swipe_info.at(swipeid).adventure_options newmessages.at(id + 1).swipe_id = swipeid saveChatFile(newmessages) @@ -220,10 +221,10 @@ const ChatItem = ({ message, id, scroll}) => { newmessages.at(id + 1).send_date = messages.at(id + 1).swipe_info.at(swipeid).send_date newmessages.at(id + 1).gen_started = messages.at(id + 1).swipe_info.at(swipeid).gen_started newmessages.at(id + 1).gen_finished = messages.at(id + 1).swipe_info.at(swipeid).gen_finished + newmessages.at(id + 1).adventure_options = messages.at(id + 1).swipe_info.at(swipeid).adventure_options newmessages.at(id + 1).swipe_id = swipeid - saveChatFile(newmessages) return newmessages }) diff --git a/constants/global.tsx b/constants/global.tsx index bcd763b..a41a181 100644 --- a/constants/global.tsx +++ b/constants/global.tsx @@ -87,6 +87,11 @@ export const enum Global { NovelModel='novelmodel', // novelai model AphroditeKey = 'aphroditekey', // api key for aphrodite, default is `EMPTY` + + // ADVENTURE + + AdventureEnabled = `adventureEnabled`, + AdventureSettings = 'adventuresettings', } export const enum API { @@ -393,7 +398,7 @@ const createNewChat = (userName : any, characterName : any, initmessage : any) = {"name":characterName,"is_user":false,"send_date":humanizedISO8601DateTime(), "mes":initmessage .replaceAll(`{{char}}`, mmkv.getString(Global.CurrentCharacter)) - .replaceAll(`{{user}}`, mmkv.getString(Global.CurrentUser)) + .replaceAll(`{{user}}`, mmkv.getString(Global.CurrentUser)) }, ] } @@ -411,7 +416,8 @@ export const createNewDefaultChat = ( {encoding: FS.EncodingType.UTF8}) .then( response => { let card = JSON.parse(response) - const newmessage = createNewChat(userName, charName, ( card?.data?.first_mes ?? card.first_mes )) + let newmessage : any = createNewChat(userName, charName, ( card?.data?.first_mes ?? card.first_mes )) + newmessage[1].adventure_options = card?.adventure_options ?? card?.data?.adventure_options ?? "" return FS.writeAsStringAsync( `${FS.documentDirectory}characters/${charName}/chats/${newmessage[0].create_date}.jsonl`, newmessage.map((item: any)=> JSON.stringify(item)).join('\u000d\u000a'), @@ -510,6 +516,7 @@ export const createChatEntry = (name : string, is_user : string, message : strin "extra":{"api":api,"model":model}, "swipe_id":0, "swipes":[message], + "adventure_options" : "", "swipe_info":[ // metadata { @@ -517,6 +524,7 @@ export const createChatEntry = (name : string, is_user : string, message : strin "extra":{"api":api,"model":model}, "gen_started" : new Date(), "gen_finished" : new Date(), + "adventure_options" : "", }, ], } diff --git a/lib/Inference.js b/lib/Inference.js index b8aad43..a5a2b2a 100644 --- a/lib/Inference.js +++ b/lib/Inference.js @@ -17,11 +17,15 @@ export const generateResponse = (setAbortFunction, insertGeneratedMessage, messa console.log(`Obtaining response.`) const APIType = getString(Global.APIType) - + const adventureMode = getBool(Global.AdventureEnabled) + console.log(adventureMode) try { + if(adventureMode) + KAIresponse(setAbortFunction, insertGeneratedMessage, messages) + else switch(APIType) { case API.KAI: - KAIresponse(setAbortFunction, insertGeneratedMessage, messages) + KAIresponseStream(setAbortFunction, insertGeneratedMessage, messages) break case API.HORDE: hordeResponse(setAbortFunction, insertGeneratedMessage, messages) @@ -37,10 +41,15 @@ export const generateResponse = (setAbortFunction, insertGeneratedMessage, messa } } -const adventure_grammar = `text ::= [^\\n]+ "\\n" -action ::= [^\\n"."]+ ".\\n" -emotion ::= "[" ("SAD" | "HAPPY" | "NEUTRAL" | "ANGRY" | "FLIRT" | "CONFUSED") "] " -root ::= text text? "\`\`\`\\n{{Actions}}:\\n" emotion action emotion action emotion action ` +const adventure_grammar = () => { + + let emotions = ["SAD" , "HAPPY" , "NEUTRAL" , "ANGRY" , "CONFUSED"] + emotions.sort() + + return `text ::= [^\\n]+ "\\n" + action ::= [^\\n^\".\"]+ ".\\n" + root ::= "\`\`\`\\n{{Actions}}:\\n" "[${emotions[0]}] " action "[${emotions[1]}] " action "[${emotions[2]}] " action ` +} // MMKV @@ -52,6 +61,10 @@ const getString = (key) => { return mmkv.getString(key) ?? "" } +const getBool = (key) => { + return mmkv.getBoolean(key) +} + const setValue = (key, value) => { mmkv.set(key ,value) } @@ -79,8 +92,9 @@ const buildContext = (max_length, messages) => { } message_acc += message_shard } - if (messages.at(-1).name === charName) { - payload += message_acc + currentInstruct.output_sequence + if (messages.at(-1).name !== charName) { + console.log(`Adding output sequence`) + payload += message_acc + `\n` + currentInstruct.output_sequence } else { payload += message_acc.trim('\n') @@ -93,10 +107,10 @@ const buildContext = (max_length, messages) => { // Payloads const constructKAIPayload = (messages) => { - + const adventureMode = getBool(Global.AdventureEnabled) const presetKAI = getObject(Global.PresetKAI) const currentInstruct = getObject(Global.CurrentInstruct) - + const charName = getString(Global.CurrentCharacter) return { "prompt": buildContext(presetKAI.max_context_length, messages), "use_story": false, @@ -125,7 +139,7 @@ const constructKAIPayload = (messages) => { "mirostat_tau": presetKAI.mirostat_tau, "mirostat_eta": presetKAI.mirostat_eta, "min_p" : presetKAI.min_p, - "grammar" : adventure_grammar + "grammar" : (adventureMode && messages.at(-1).name === charName && messages.at(-1).adventure_options === '' ) ? adventure_grammar(): presetKAI.grammar } } @@ -223,14 +237,14 @@ const constructTGWUIPayload = (messages) => { // Fetch Response -const KAIresponse = (setAbortFunction, insertGeneratedMessage, messages) => { +const KAIresponseStream = (setAbortFunction, insertGeneratedMessage, messages) => { const kaiendpoint = getString(Global.KAIEndpoint) const controller = new AbortController(); const timeout = setTimeout(() => controller.abort(), 60000); let aborted = false - console.log(`Using KAI`) + console.log(`Using KAI Stream`) setAbortFunction(abortFunction => () => { controller.abort() aborted = true @@ -274,6 +288,42 @@ const KAIresponse = (setAbortFunction, insertGeneratedMessage, messages) => { }) } +const KAIresponse = async (setAbortFunction, insertGeneratedMessage, messages) => { + const kaiendpoint = getString(Global.KAIEndpoint) + const controller = new AbortController(); + // const timeout = setTimeout(() => controller.abort(), 60000); + let aborted = false + + console.log(`Using KAI Stream`) + setAbortFunction(abortFunction => () => { + controller.abort() + aborted = true + axios + .create({timeout: 1000}) + .post(`${kaiendpoint}/api/extra/abort`) + .catch(() => {ToastAndroid.show(`Abort Failed`, 2000)}) + }) + + fetch(`${kaiendpoint}/api/v1/generate`, { + reactNative: {textStreaming: true}, + method: `POST`, + body: JSON.stringify(constructKAIPayload(messages)), + signal: controller.signal, + }, {}).then( + (response) => {return response.json()} + ).then((response) => { + setValue(Global.NowGenerating, false) + insertGeneratedMessage(response.results[0].text ?? '') + }).catch((error) => { + setValue(Global.NowGenerating, false) + if(!aborted) + ToastAndroid.show('Connection Lost...', ToastAndroid.SHORT) + console.log('KAI Response failed: ' + error) + }) + +} + + const hordeResponse = async (setAbortFunction, insertGeneratedMessage, messages) => { const hordeKey = getString(Global.HordeKey)