update logic to make llamacpp a single-class object
samlhuillier committed Dec 5, 2023
1 parent 8847cca commit bff6362
Showing 7 changed files with 88 additions and 84 deletions.
37 changes: 23 additions & 14 deletions electron/main/index.ts
@@ -26,6 +26,7 @@ import { registerLLMSessionHandlers } from "./llm/llmSessionHandlers";
 import { FileInfoTree } from "./Files/Types";
 import { registerDBSessionHandlers } from "./database/dbSessionHandlers";
 import { validateAIModelConfig } from "./llm/llmConfig";
+import AIModelManager from "@/components/Settings/LLMSettings";

 const store = new Store<StoreSchema>();
 // const user = store.get("user");

@@ -58,35 +59,43 @@ if (!app.requestSingleInstanceLock()) {
 const markdownExtensions = [".md", ".markdown", ".mdown", ".mkdn", ".mkd"];

-const defaultAIModels = {
+const defaultAIModels: { [modelName: string]: AIModelConfig } = {
   "gpt-3.5-turbo-1106": {
-    localpath: "",
-    contextlength: 16385,
+    localPath: "",
+    contextLength: 16385,
     engine: "openai",
   },
   "gpt-4-1106-preview": {
-    localpath: "",
-    contextlength: 128000,
+    localPath: "",
+    contextLength: 128000,
     engine: "openai",
   },
   "gpt-4-0613": {
-    localpath: "",
-    contextlength: 8192,
+    localPath: "",
+    contextLength: 8192,
     engine: "openai",
   },
   "gpt-4-32k-0613": {
-    localpath: "",
-    contextlength: 32768,
+    localPath: "",
+    contextLength: 32768,
     engine: "openai",
   },
 };

-const currentAIModels = store.get(StoreKeys.AIModels);
-const hasModelsChanged =
-  JSON.stringify(currentAIModels) !== JSON.stringify(defaultAIModels);
+const currentAIModels =
+  (store.get(StoreKeys.AIModels) as Record<string, AIModelConfig>) || {};

-if (!currentAIModels || hasModelsChanged) {
-  store.set(StoreKeys.AIModels, defaultAIModels);
+// Merge default models with existing ones
+const updatedModels = { ...currentAIModels };
+for (const [modelName, modelConfig] of Object.entries(defaultAIModels)) {
+  if (!updatedModels[modelName]) {
+    updatedModels[modelName] = modelConfig;
+  }
+}
+
+// Save the updated models if they are different from the current models
+if (JSON.stringify(currentAIModels) !== JSON.stringify(updatedModels)) {
+  store.set(StoreKeys.AIModels, updatedModels);
 }

 // Remove electron security warnings
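
The new merge logic is deliberately non-destructive: a default is only added when the user has no entry under that model name, so locally edited configs survive updates. A minimal standalone sketch of the same pattern (the `mergeDefaults` name and the simplified `AIModelConfig` shape are illustrative, not from this commit):

```ts
// Sketch: non-destructive merge of default model configs into stored ones.
interface AIModelConfig {
  localPath: string;
  contextLength: number;
  engine: string;
}

function mergeDefaults(
  current: Record<string, AIModelConfig>,
  defaults: Record<string, AIModelConfig>
): Record<string, AIModelConfig> {
  const merged = { ...current };
  for (const [name, config] of Object.entries(defaults)) {
    if (!merged[name]) {
      merged[name] = config; // add missing defaults, keep user edits intact
    }
  }
  return merged;
}
```

The `JSON.stringify` comparison before `store.set` then avoids a redundant disk write when nothing was added.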
10 changes: 5 additions & 5 deletions electron/main/llm/llmSessionHandlers.ts
@@ -1,5 +1,5 @@
 import { ipcMain, IpcMainInvokeEvent } from "electron";
-import { LlamaCPPModelLoader, LlamaCPPSessionService } from "./models/LlamaCpp"; // Assuming SessionService is in the same directory
+import { LlamaCPPSessionService } from "./models/LlamaCpp"; // Assuming SessionService is in the same directory
 import { ISessionService } from "./Types";
 import { OpenAIModel, OpenAIModelSessionService } from "./models/OpenAI";
 import { AIModelConfig, StoreKeys, StoreSchema } from "../Config/storeConfig";

@@ -13,13 +13,13 @@ const sessions: { [sessionId: string]: ISessionService } = {};
 export const registerLLMSessionHandlers = (store: Store<StoreSchema>) => {
   const openAIAPIKey: string = store.get(StoreKeys.UserOpenAIAPIKey);

-  const llamaCPPModelLoader = new LlamaCPPModelLoader();
-  llamaCPPModelLoader.loadModel();
+  // const llamaCPPModelLoader = new LlamaCPPModelLoader();
+  // llamaCPPModelLoader.loadModel();
   // const gpt4SessionService = new GPT4SessionService(gpt4Model, webContents);
   // await gpt4SessionService.init();

   ipcMain.handle(
-    "getOrCreateSession",
+    "get-or-create-session",
     async (event: IpcMainInvokeEvent, sessionId: string) => {
       if (sessions[sessionId]) {
         return sessionId;

@@ -37,7 +37,7 @@ export const registerLLMSessionHandlers = (store: Store<StoreSchema>) => {
         sessions[sessionId] = sessionService;
         return sessionId;
       } else {
-        const sessionService = new LlamaCPPSessionService(llamaCPPModelLoader);
+        const sessionService = new LlamaCPPSessionService();
         sessions[sessionId] = sessionService;
         return sessionId;
       }
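
Two changes land here: the eager, module-level `LlamaCPPModelLoader` is retired (the session service now owns its model, per the commit title), and the IPC channel moves to kebab-case. The handler itself is a get-or-create cache keyed by session id; distilled into a sketch with simplified types (not the literal handler code):

```ts
// Sketch: get-or-create cache behind the "get-or-create-session" channel.
const sessions: { [sessionId: string]: unknown } = {};

function getOrCreateSession(
  sessionId: string,
  createService: () => unknown // engine-specific constructor, chosen per config
): string {
  if (!sessions[sessionId]) {
    sessions[sessionId] = createService(); // constructed once, then reused
  }
  return sessionId; // the renderer only ever receives the id, not the service
}
```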
80 changes: 44 additions & 36 deletions electron/main/llm/models/LlamaCpp.ts
@@ -2,55 +2,58 @@ import path from "path";
 import os from "os";
 import { IModel, ISendFunctionImplementer, ISessionService } from "../Types";

-export class LlamaCPPModelLoader implements IModel {
-  private model: any;
+export class LlamaCPPSessionService implements ISessionService {
+  private session: any;
+  public context: any;
+  private model: any; // Model instance

-  async loadModel(): Promise<void> {
-    // Load model logic
-    this.model = await import("node-llama-cpp").then((nodeLLamaCpp: any) => {
-      return new nodeLLamaCpp.LlamaModel({
-        modelPath: path.join(
-          os.homedir(),
-          "Downloads",
-          "tinyllama-2-1b-miniguanaco.Q2_K.gguf"
-          // "mistral-7b-v0.1.Q4_K_M.gguf"
-        ),
-        gpuLayers: 0,
-      });
-    });
+  constructor() {
+    this.init();
   }

-  public async getModel(): Promise<any> {
-    if (!this.model) {
-      throw new Error("Model not initialized");
-    }
-    return this.model;
+  private async loadModel(): Promise<void> {
+    // Load model logic - similar to what was in LlamaCPPModelLoader
+    const nodeLLamaCpp = await import("node-llama-cpp");
+    this.model = new nodeLLamaCpp.LlamaModel({
+      modelPath: path.join(
+        os.homedir(),
+        "Downloads",
+        "tinyllama-2-1b-miniguanaco.Q2_K.gguf"
+        // "mistral-7b-v0.1.Q4_K_M.gguf"
+      ),
+      gpuLayers: 0,
+    });
+
+    // this.model = await import("node-llama-cpp").then((nodeLLamaCpp: any) => {
+    //   return new nodeLLamaCpp.LlamaModel({
+    //     modelPath: path.join(
+    //       os.homedir(),
+    //       "Downloads",
+    //       "tinyllama-2-1b-miniguanaco.Q2_K.gguf"
+    //       // "mistral-7b-v0.1.Q4_K_M.gguf"
+    //     ),
+    //     gpuLayers: 0,
+    //   });
+    // });
   }

-  async unloadModel(): Promise<void> {
+  private async unloadModel(): Promise<void> {
     // Unload model logic
     this.model = null;
   }

-  isModelLoaded(): boolean {
+  private isModelLoaded(): boolean {
     return !!this.model;
   }
-}

-export class LlamaCPPSessionService implements ISessionService {
-  private session: any;
-  public context: any;
-  private modelLoader: LlamaCPPModelLoader;
-
-  constructor(modelLoader: LlamaCPPModelLoader) {
-    this.modelLoader = modelLoader;
-    this.init();
-  }
   async init() {
-    // eslint-disable-next-line @typescript-eslint/no-var-requires
+    await this.loadModel();
+    if (!this.isModelLoaded()) {
+      throw new Error("Model not loaded");
+    }

     import("node-llama-cpp").then(async (nodeLLamaCpp: any) => {
-      const model = await this.modelLoader.getModel();
-      this.context = new nodeLLamaCpp.LlamaContext({ model });
+      this.context = new nodeLLamaCpp.LlamaContext({ model: this.model });
       this.session = new nodeLLamaCpp.LlamaChatSession({
         context: this.context,
       });

@@ -71,7 +74,12 @@ export class LlamaCPPSessionService implements ISessionService {
       topP: 0.02,
       onToken: (chunk: any[]) => {
         const decodedChunk = this.context.decode(chunk);
-        sendFunctionImplementer.send("tokenStream", decodedChunk);
+        console.log("decoded chunk: ", decodedChunk);
+        sendFunctionImplementer.send("tokenStream", {
+          messageType: "success",
+          message: decodedChunk,
+        });
+        // sendFunctionImplementer.send("tokenStream", decodedChunk);
       },
     });
   }
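
One caveat with the consolidated class: the constructor kicks off `this.init()` without awaiting it, so a prompt sent immediately after construction can race the model load (the `init` body also starts an un-awaited `import(...).then(...)` chain). A common guard, not part of this commit, is to keep the init promise and await it at every public entry point; a sketch with illustrative names:

```ts
// Sketch: readiness guard for fire-and-forget async initialization.
class OnceLoadedService {
  private ready: Promise<void>;

  constructor() {
    this.ready = this.init(); // keep the promise instead of discarding it
  }

  private async init(): Promise<void> {
    // model/context/session setup would go here
  }

  async prompt(text: string): Promise<string> {
    await this.ready; // every public method waits for init to finish
    return text; // placeholder for real inference
  }
}
```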
11 changes: 0 additions & 11 deletions electron/main/llm/models/OpenAI.ts
@@ -89,14 +89,3 @@ export class OpenAIModelSessionService implements ISessionService {
     }
   }
 }
-
-// const fetchOpenAIModels = async (apiKey: string) => {
-//   try {
-//     const openai = new OpenAI({ apiKey });
-//     const modelsResponse = await openai.models.list();
-//     return modelsResponse.data;
-//   } catch (error) {
-//     console.error("Error fetching models from OpenAI:", error);
-//     return [];
-//   }
-// };
2 changes: 1 addition & 1 deletion electron/preload/index.ts
@@ -163,7 +163,7 @@ contextBridge.exposeInMainWorld("llm", {
   //   return await ipcRenderer.invoke("createSession", sessionId);
   // },
   getOrCreateSession: async (sessionId: any) => {
-    return await ipcRenderer.invoke("getOrCreateSession", sessionId);
+    return await ipcRenderer.invoke("get-or-create-session", sessionId);
   },
   getHello: async (sessionId: any) => {
     return await ipcRenderer.invoke("getHello", sessionId);
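
This is the renderer-facing half of the channel rename: `ipcRenderer.invoke` and `ipcMain.handle` are matched purely by the channel string, and an unregistered channel rejects the invoke at runtime rather than failing at compile time, so both sides must change together. The pairing, condensed (the two snippets live in different files and processes):

```ts
import { ipcMain, ipcRenderer } from "electron";

// main process (llmSessionHandlers.ts):
ipcMain.handle("get-or-create-session", async (_event, sessionId: string) => {
  return sessionId; // create-or-reuse logic elided
});

// preload (index.ts) — the channel string must match exactly:
async function getOrCreateSession(sessionId: string): Promise<string> {
  return ipcRenderer.invoke("get-or-create-session", sessionId);
}
```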
29 changes: 15 additions & 14 deletions src/components/Chat/Chat.tsx
@@ -11,6 +11,21 @@ const ChatWithLLM: React.FC = () => {
   const [currentBotMessage, setCurrentBotMessage] =
     useState<ChatbotMessage | null>(null);

+  const initializeSession = async () => {
+    setLoading(true);
+    try {
+      console.log("Creating a new session...");
+      const newSessionId = await window.llm.getOrCreateSession(
+        "some_unique_session_id"
+      );
+      console.log("Created a new session with id:", newSessionId);
+      setSessionId(newSessionId);
+    } catch (error) {
+      console.error("Failed to create a new session:", error);
+    } finally {
+      setLoading(false);
+    }
+  };
   useEffect(() => {
     if (!sessionId) {
       initializeSession();
@@ -36,20 +51,6 @@
     }
   }, [sessionId]);

-  const initializeSession = async () => {
-    setLoading(true);
-    try {
-      const newSessionId = await window.llm.getOrCreateSession(
-        "some_unique_session_id"
-      );
-      setSessionId(newSessionId);
-    } catch (error) {
-      console.error("Failed to create a new session:", error);
-    } finally {
-      setLoading(false);
-    }
-  };
-
   const handleSubmitNewMessage = async () => {
     if (currentBotMessage) {
       setMessages((prevMessages) => [
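
Moving `initializeSession` above the `useEffect` that calls it is organizational rather than functional: the effect only runs after the whole component body has executed, so the later definition also resolved correctly. An alternative shape, not used in this commit, scopes the initializer to the effect itself so nothing is referenced before it is defined:

```ts
useEffect(() => {
  const initializeSession = async () => {
    const newSessionId = await window.llm.getOrCreateSession(
      "some_unique_session_id"
    );
    setSessionId(newSessionId);
  };
  if (!sessionId) {
    initializeSession();
  }
}, [sessionId]);
```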
3 changes: 0 additions & 3 deletions src/components/Settings/Settings.tsx
@@ -60,9 +60,6 @@ const SettingsModal: React.FC<ModalProps> = ({ isOpen, onClose }) => {
         <div className="mt-2">
           <AIModelManager />
         </div>
-        {/* <AIModelDropdown />
-        <ConfigureLLMComponent /> */}
-
         <Button
           className="bg-slate-700 mt-4 border-none h-10 hover:bg-slate-900 cursor-pointer w-[80px] text-center pt-0 pb-0 pr-2 pl-2"
           onClick={handleSave}
