diff --git a/electron/main/index.ts b/electron/main/index.ts
index 982f986f..229b7492 100644
--- a/electron/main/index.ts
+++ b/electron/main/index.ts
@@ -26,6 +26,7 @@ import { registerLLMSessionHandlers } from "./llm/llmSessionHandlers";
 import { FileInfoTree } from "./Files/Types";
 import { registerDBSessionHandlers } from "./database/dbSessionHandlers";
 import { validateAIModelConfig } from "./llm/llmConfig";
+import AIModelManager from "@/components/Settings/LLMSettings";

 const store = new Store();
 // const user = store.get("user");
@@ -58,35 +59,43 @@ if (!app.requestSingleInstanceLock()) {

 const markdownExtensions = [".md", ".markdown", ".mdown", ".mkdn", ".mkd"];

-const defaultAIModels = {
+const defaultAIModels: { [modelName: string]: AIModelConfig } = {
   "gpt-3.5-turbo-1106": {
-    localpath: "",
-    contextlength: 16385,
+    localPath: "",
+    contextLength: 16385,
     engine: "openai",
   },
   "gpt-4-1106-preview": {
-    localpath: "",
-    contextlength: 128000,
+    localPath: "",
+    contextLength: 128000,
     engine: "openai",
   },
   "gpt-4-0613": {
-    localpath: "",
-    contextlength: 8192,
+    localPath: "",
+    contextLength: 8192,
     engine: "openai",
   },
   "gpt-4-32k-0613": {
-    localpath: "",
-    contextlength: 32768,
+    localPath: "",
+    contextLength: 32768,
     engine: "openai",
   },
 };

-const currentAIModels = store.get(StoreKeys.AIModels);
-const hasModelsChanged =
-  JSON.stringify(currentAIModels) !== JSON.stringify(defaultAIModels);
+const currentAIModels =
+  (store.get(StoreKeys.AIModels) as Record<string, AIModelConfig>) || {};

-if (!currentAIModels || hasModelsChanged) {
-  store.set(StoreKeys.AIModels, defaultAIModels);
+// Merge default models with existing ones
+const updatedModels = { ...currentAIModels };
+for (const [modelName, modelConfig] of Object.entries(defaultAIModels)) {
+  if (!updatedModels[modelName]) {
+    updatedModels[modelName] = modelConfig;
+  }
+}
+
+// Save the updated models if they are different from the current models
+if (JSON.stringify(currentAIModels) !== JSON.stringify(updatedModels)) {
+  store.set(StoreKeys.AIModels, updatedModels);
 }

 // Remove electron security warnings
diff --git a/electron/main/llm/llmSessionHandlers.ts b/electron/main/llm/llmSessionHandlers.ts
index b9350a93..69e6d264 100644
--- a/electron/main/llm/llmSessionHandlers.ts
+++ b/electron/main/llm/llmSessionHandlers.ts
@@ -1,5 +1,5 @@
 import { ipcMain, IpcMainInvokeEvent } from "electron";
-import { LlamaCPPModelLoader, LlamaCPPSessionService } from "./models/LlamaCpp"; // Assuming SessionService is in the same directory
+import { LlamaCPPSessionService } from "./models/LlamaCpp"; // Assuming SessionService is in the same directory
 import { ISessionService } from "./Types";
 import { OpenAIModel, OpenAIModelSessionService } from "./models/OpenAI";
 import { AIModelConfig, StoreKeys, StoreSchema } from "../Config/storeConfig";
@@ -13,13 +13,13 @@ const sessions: { [sessionId: string]: ISessionService } = {};
 export const registerLLMSessionHandlers = (store: Store) => {
   const openAIAPIKey: string = store.get(StoreKeys.UserOpenAIAPIKey);
-  const llamaCPPModelLoader = new LlamaCPPModelLoader();
-  llamaCPPModelLoader.loadModel();
+  // const llamaCPPModelLoader = new LlamaCPPModelLoader();
+  // llamaCPPModelLoader.loadModel();
   // const gpt4SessionService = new GPT4SessionService(gpt4Model, webContents);
   // await gpt4SessionService.init();

   ipcMain.handle(
-    "getOrCreateSession",
+    "get-or-create-session",
     async (event: IpcMainInvokeEvent, sessionId: string) => {
       if (sessions[sessionId]) {
         return sessionId;
@@ -37,7 +37,7 @@ export const registerLLMSessionHandlers = (store: Store) => {
         sessions[sessionId] = sessionService;
         return sessionId;
       } else {
-        const sessionService = new LlamaCPPSessionService(llamaCPPModelLoader);
+        const sessionService = new LlamaCPPSessionService();
         sessions[sessionId] = sessionService;
         return sessionId;
       }
diff --git a/electron/main/llm/models/LlamaCpp.ts b/electron/main/llm/models/LlamaCpp.ts
index 6127a760..b4497690 100644
--- a/electron/main/llm/models/LlamaCpp.ts
+++ b/electron/main/llm/models/LlamaCpp.ts
@@ -2,55 +2,58 @@ import path from "path";
 import os from "os";
 import { IModel, ISendFunctionImplementer, ISessionService } from "../Types";

-export class LlamaCPPModelLoader implements IModel {
-  private model: any;
+export class LlamaCPPSessionService implements ISessionService {
+  private session: any;
+  public context: any;
+  private model: any; // Model instance

-  async loadModel(): Promise<void> {
-    // Load model logic
-    this.model = await import("node-llama-cpp").then((nodeLLamaCpp: any) => {
-      return new nodeLLamaCpp.LlamaModel({
-        modelPath: path.join(
-          os.homedir(),
-          "Downloads",
-          "tinyllama-2-1b-miniguanaco.Q2_K.gguf"
-          // "mistral-7b-v0.1.Q4_K_M.gguf"
-        ),
-        gpuLayers: 0,
-      });
-    });
+  constructor() {
+    this.init();
   }

-  public async getModel(): Promise<any> {
-    if (!this.model) {
-      throw new Error("Model not initialized");
-    }
-    return this.model;
+  private async loadModel(): Promise<void> {
+    // Load model logic - similar to what was in LlamaCPPModelLoader
+    const nodeLLamaCpp = await import("node-llama-cpp");
+    this.model = new nodeLLamaCpp.LlamaModel({
+      modelPath: path.join(
+        os.homedir(),
+        "Downloads",
+        "tinyllama-2-1b-miniguanaco.Q2_K.gguf"
+        // "mistral-7b-v0.1.Q4_K_M.gguf"
+      ),
+      gpuLayers: 0,
+    });
+
+    // this.model = await import("node-llama-cpp").then((nodeLLamaCpp: any) => {
+    //   return new nodeLLamaCpp.LlamaModel({
+    //     modelPath: path.join(
+    //       os.homedir(),
+    //       "Downloads",
+    //       "tinyllama-2-1b-miniguanaco.Q2_K.gguf"
+    //       // "mistral-7b-v0.1.Q4_K_M.gguf"
+    //     ),
+    //     gpuLayers: 0,
+    //   });
+    // });
   }

-  async unloadModel(): Promise<void> {
+  private async unloadModel(): Promise<void> {
     // Unload model logic
     this.model = null;
   }

-  isModelLoaded(): boolean {
+  private isModelLoaded(): boolean {
     return !!this.model;
   }
-}
-
-export class LlamaCPPSessionService implements ISessionService {
-  private session: any;
-  public context: any;
-  private modelLoader: LlamaCPPModelLoader;
-
-  constructor(modelLoader: LlamaCPPModelLoader) {
-    this.modelLoader = modelLoader;
-    this.init();
-  }

   async init() {
-    // eslint-disable-next-line @typescript-eslint/no-var-requires
+    await this.loadModel();
+    if (!this.isModelLoaded()) {
+      throw new Error("Model not loaded");
+    }
+
     import("node-llama-cpp").then(async (nodeLLamaCpp: any) => {
-      const model = await this.modelLoader.getModel();
-      this.context = new nodeLLamaCpp.LlamaContext({ model });
+      this.context = new nodeLLamaCpp.LlamaContext({ model: this.model });
       this.session = new nodeLLamaCpp.LlamaChatSession({
         context: this.context,
       });
@@ -71,7 +74,12 @@ export class LlamaCPPSessionService implements ISessionService {
       topP: 0.02,
       onToken: (chunk: any[]) => {
         const decodedChunk = this.context.decode(chunk);
-        sendFunctionImplementer.send("tokenStream", decodedChunk);
+        console.log("decoded chunk: ", decodedChunk);
+        sendFunctionImplementer.send("tokenStream", {
+          messageType: "success",
+          message: decodedChunk,
+        });
+        // sendFunctionImplementer.send("tokenStream", decodedChunk);
       },
     });
   }
diff --git a/electron/main/llm/models/OpenAI.ts b/electron/main/llm/models/OpenAI.ts
index ff8a095c..0d521823 100644
--- a/electron/main/llm/models/OpenAI.ts
+++ b/electron/main/llm/models/OpenAI.ts
@@ -89,14 +89,3 @@ export class OpenAIModelSessionService implements ISessionService {
       }
     }
   }
 }
-
-// const fetchOpenAIModels = async (apiKey: string) => {
-//   try {
-//     const openai = new OpenAI({ apiKey });
-//     const modelsResponse = await openai.models.list();
-//     return modelsResponse.data;
-//   } catch (error) {
-//     console.error("Error fetching models from OpenAI:", error);
-//     return [];
-//   }
-// };
diff --git a/electron/preload/index.ts b/electron/preload/index.ts
index f31545c2..18dcf4da 100644
--- a/electron/preload/index.ts
+++ b/electron/preload/index.ts
@@ -163,7 +163,7 @@ contextBridge.exposeInMainWorld("llm", {
   //   return await ipcRenderer.invoke("createSession", sessionId);
   // },
   getOrCreateSession: async (sessionId: any) => {
-    return await ipcRenderer.invoke("getOrCreateSession", sessionId);
+    return await ipcRenderer.invoke("get-or-create-session", sessionId);
   },
   getHello: async (sessionId: any) => {
     return await ipcRenderer.invoke("getHello", sessionId);
diff --git a/src/components/Chat/Chat.tsx b/src/components/Chat/Chat.tsx
index 15f46fa1..aae7c79c 100644
--- a/src/components/Chat/Chat.tsx
+++ b/src/components/Chat/Chat.tsx
@@ -11,6 +11,21 @@ const ChatWithLLM: React.FC = () => {
   const [currentBotMessage, setCurrentBotMessage] = useState(null);
+  const initializeSession = async () => {
+    setLoading(true);
+    try {
+      console.log("Creating a new session...");
+      const newSessionId = await window.llm.getOrCreateSession(
+        "some_unique_session_id"
+      );
+      console.log("Created a new session with id:", newSessionId);
+      setSessionId(newSessionId);
+    } catch (error) {
+      console.error("Failed to create a new session:", error);
+    } finally {
+      setLoading(false);
+    }
+  };
   useEffect(() => {
     if (!sessionId) {
       initializeSession();
     }
   }, [sessionId]);
@@ -36,20 +51,6 @@ const ChatWithLLM: React.FC = () => {
     }
   }, [sessionId]);

-  const initializeSession = async () => {
-    setLoading(true);
-    try {
-      const newSessionId = await window.llm.getOrCreateSession(
-        "some_unique_session_id"
-      );
-      setSessionId(newSessionId);
-    } catch (error) {
-      console.error("Failed to create a new session:", error);
-    } finally {
-      setLoading(false);
-    }
-  };
-
   const handleSubmitNewMessage = async () => {
     if (currentBotMessage) {
       setMessages((prevMessages) => [
diff --git a/src/components/Settings/Settings.tsx b/src/components/Settings/Settings.tsx
index 39ef0d88..33847f60 100644
--- a/src/components/Settings/Settings.tsx
+++ b/src/components/Settings/Settings.tsx
@@ -60,9 +60,6 @@ const SettingsModal: React.FC = ({ isOpen, onClose }) => {
-          {/*
-          */}
-