From 28666f3160e979e5694f8c74ffd4853478f2c8e6 Mon Sep 17 00:00:00 2001 From: Shafil Alam Date: Thu, 7 Nov 2024 22:58:39 -0500 Subject: [PATCH] Fix major bugs with server --- src/ai.ts | 158 +++++++++++++++++++------------------- src/server.ts | 16 ++-- ui/components/options.tsx | 2 +- 3 files changed, 92 insertions(+), 84 deletions(-) diff --git a/src/ai.ts b/src/ai.ts index 0339047..a078e6b 100644 --- a/src/ai.ts +++ b/src/ai.ts @@ -86,95 +86,100 @@ export class OpenAIGen extends AIGen { * @throws Error if API call fails */ static async generate(log: (msg: string) => void, systemPrompt: string, userPrompt: string, apiKey?: string, options?: AIOptions): Promise { - const endpoint = options?.endpoint ?? this.DEFAULT_ENDPOINT; - const model = options?.model ?? this.DEFAULT_MODEL; + try { + const endpoint = options?.endpoint ?? this.DEFAULT_ENDPOINT; + const model = options?.model ?? this.DEFAULT_MODEL; - log(`Using OpenAI model: ${model}`); - log(`Calling OpenAI API with endpoint: ${endpoint}`); + log(`Using OpenAI model: ${model}`); + log(`Calling OpenAI API with endpoint: ${endpoint}`); - if (apiKey == "" || apiKey == undefined) { - log("[*] Warning: OpenAI API key is not set! Set via '--openaiAPIKey' flag or define 'OPENAI_API_KEY' environment variable."); - } + if (apiKey == "" || apiKey == undefined) { + log("[*] Warning: OpenAI API key is not set! Set via '--openaiAPIKey' flag or define 'OPENAI_API_KEY' environment variable."); + } - const client = new OpenAI({ - apiKey: apiKey, - baseURL: endpoint, - }); + const client = new OpenAI({ + apiKey: apiKey, + baseURL: endpoint, + }); - const messages = [ - { role: 'system', content: systemPrompt }, - { role: 'user', content: INITIAL_AI_PROMPT + userPrompt } - ]; + const messages = [ + { role: 'user', content: systemPrompt }, + { role: 'user', content: INITIAL_AI_PROMPT + userPrompt } + ]; - const response = await client.chat.completions.create({ - model: model, - messages: messages as ChatCompletionMessageParam[], - stream: true, - response_format: { "type": "json_object" } - }); + const response = await client.chat.completions.create({ + max_tokens: 1024, + model: model, + messages: messages as OpenAI.ChatCompletionMessageParam[], + response_format: { "type": "json_object" } + }); - let aiResponse = ''; - let videoType = ''; + let aiResponse = ''; + let videoType = ''; - for await (const chunk of response) { - videoType += chunk.choices[0]?.delta?.content; - log(`AI Response chunk -> ${chunk.choices[0]?.delta?.content?.trim()}`); - } + videoType = response.choices[0].message.content ?? ""; - videoType = videoType.trim(); + messages.push({ role: 'assistant', content: videoType }); - // Get AI prompts for video type - // Check if AI string matches any type in VideoGenType enum - const videoTypeValues = Object.values(VideoGenType); - let matchedType = videoTypeValues.find((val) => val.toLowerCase() === videoType.toLowerCase()); + videoType = videoType.trim(); - if (!matchedType) { - log(`[*] Invalid video type (defaulting to topic): '${videoType}'`); - matchedType = VideoGenType.TopicVideo; - } + // Get video type from AI response + // JSON parse "type" and check if it matches any type in VideoGenType enum + const videoTypeValues = Object.values(VideoGenType); + const videoTypeJson = JSON.parse(videoType)["type"]; + let matchedType = videoTypeValues.find((val) => val.toLowerCase() === videoTypeJson.toLowerCase()); - log(`(OpenAI ${model}) AI said video type is '${matchedType}'`); + if (!matchedType) { + log(`[*] Invalid video type (defaulting to topic): '${videoType}'`); + matchedType = VideoGenType.TopicVideo; + } - const videoGenType = matchedType as VideoGenType; - const aiPrompt = convertVideoTypeToPrompt(videoGenType); + log(`(OpenAI ${model}) AI said video type is '${matchedType}'`); - // Get each prompt from each field and add to JSON - const videoJson: any = {}; + const videoGenType = matchedType as VideoGenType; - videoJson["type"] = videoGenType; + // Get AI prompts for video type + const aiPrompt = convertVideoTypeToPrompt(videoGenType); - for (const [key, value] of Object.entries(aiPrompt)) { - const prompt = value; + // Get each prompt from each field and add to JSON + const videoJson: any = {}; - log(`(OpenAI ${model}) Will ask AI for field '${key}' with prompt '${prompt}'`); + videoJson["type"] = videoGenType; - messages.push({ role: 'user', content: prompt }); - const response = await client.chat.completions.create({ - model: model, - messages: messages as ChatCompletionMessageParam[], - stream: true, - }); + for (const [key, value] of Object.entries(aiPrompt)) { + const prompt = value; - let res = ''; - for await (const chunk of response) { - res += chunk.choices[0]?.delta?.content; - log(`AI Response chunk -> ${chunk.choices[0]?.delta?.content?.trim()}`); - } + log(`(OpenAI ${model}) Will ask AI for field '${key}' with prompt '${prompt}'`); - // Try to parse JSON response and validate it - try { - const jsonRes = JSON.parse(res.trim()); - videoJson[key] = jsonRes[key] ?? jsonRes; - } catch (error: any) { - console.info(`(Google AI ${model}) Error parsing JSON response: ${error.message}`); - } + messages.push({ role: 'user', content: prompt }); + const response = await client.chat.completions.create({ + max_tokens: 1024, + model: model, + messages: messages as OpenAI.ChatCompletionMessageParam[], + }); - log(`(OpenAI ${model}) AI said for field '${key}' is '${res}'`); - } + let res = response.choices[0].message.content ?? ""; - aiResponse = JSON.stringify(videoJson, null, 2); + messages.push({ role: 'assistant', content: res }); - return aiResponse; + // Try to parse JSON response and validate it + try { + const jsonRes = JSON.parse(res.trim()); + videoJson[key] = jsonRes[key] ?? jsonRes; + } catch (error: any) { + console.info(`(OpenAI ${model}) Error parsing JSON response: ${error.message}`); + } + + log(`(OpenAI ${model}) AI said for field '${key}' is '${res}'`); + } + + aiResponse = JSON.stringify(videoJson, null, 2); + + return aiResponse; + } catch (error: any) { + console.log("Error while calling OpenAI API: " + error.message); + } + return ""; // Return empty string if error } /** @@ -385,6 +390,7 @@ export class AnthropicAIGen extends AIGen { const videoGenType = matchedType as VideoGenType; + // Get AI prompts for video type const aiPrompt = convertVideoTypeToPrompt(videoGenType); // Get each prompt from each field and add to JSON @@ -423,9 +429,9 @@ export class AnthropicAIGen extends AIGen { const jsonRes = JSON.parse(res.trim()); videoJson[key] = jsonRes[key] ?? jsonRes; } catch (error: any) { - console.info(`(Google AI ${model}) Error parsing JSON response: ${error.message}`); + console.info(`(Anthropic AI ${model}) Error parsing JSON response: ${error.message}`); } - + log(`(Anthropic ${model}) AI said for field '${key}' is '${res}'`); } @@ -482,6 +488,8 @@ export class OllamaAIGen extends AIGen { log(`AI Response chunk for 'type' -> ${msgChunk.trim()}`); } + messages.push({ role: 'assistant', content: videoType }); + videoType = videoType.trim(); // Get AI prompts for video type @@ -510,15 +518,9 @@ export class OllamaAIGen extends AIGen { log(`(Ollama ${model}) Will ask AI for field '${key}' with prompt '${prompt}'`); - // messages.push({ role: 'user', content: prompt }); + messages.push({ role: 'user', content: prompt }); - // Low param Local LLMs lose context after a few turns, so we need to reset the context - const innerMessages = [ - { role: 'user', content: systemPrompt }, - { role: 'user', content: `${prompt}\n User comment: ${userPrompt}` } - ]; - - const response = await ollama.chat({ model: model, messages: innerMessages, stream: true, format: 'json' }); + const response = await ollama.chat({ model: model, messages: messages, stream: true, format: 'json' }); let res = ''; for await (const part of response) { @@ -527,6 +529,8 @@ export class OllamaAIGen extends AIGen { log(`AI Response chunk for '${key}' -> ${msgChunk.trim()}`); } + messages.push({ role: 'assistant', content: res }); + // Try to parse JSON response and validate it try { const jsonRes = JSON.parse(res); diff --git a/src/server.ts b/src/server.ts index 2a66b92..b4e309e 100644 --- a/src/server.ts +++ b/src/server.ts @@ -218,7 +218,8 @@ export async function runAPIServer() { }, aiAPIKey, { - endpoint: data.openAIEndpoint + endpoint: data.openAIEndpoint, + model: data.aiModel } ); @@ -291,7 +292,7 @@ export async function runAPIServer() { }, ); - console.info("Stating live log stream to client..."); + console.info("Starting live log stream to client..."); task.on('log', (log: string) => { // Send log to client in JSON @@ -407,6 +408,9 @@ export async function runAPIServer() { // Get AI type const aiTypeStr = req.query.type as string; + // Get openAI endpoint if provided + const openAIEndpoint = req.query.endpoint as string; + // Check if empty if (!aiTypeStr) { res.status(400).json({ @@ -443,17 +447,17 @@ export async function runAPIServer() { models = await OllamaAIGen.getModels(); break; case AIGenType.OpenAIGen: - apiKey = process.env[AIGenType.OpenAIGen]; + apiKey = process.env[AIAPIEnv.OpenAIGen]; if (!apiKey) return errorIfNoAPIKey(); - models = await OpenAIGen.getModels(); + models = await OpenAIGen.getModels(apiKey, { endpoint: openAIEndpoint }); break; case AIGenType.GoogleAIGen: - apiKey = process.env[AIGenType.GoogleAIGen]; + apiKey = process.env[AIAPIEnv.GoogleAIGen]; if (!apiKey) return errorIfNoAPIKey(); models = await GoogleAIGen.getModels(); break; case AIGenType.AnthropicAIGen: - apiKey = process.env[AIGenType.AnthropicAIGen]; + apiKey = process.env[AIAPIEnv.AnthropicAIGen]; if (!apiKey) return errorIfNoAPIKey(); models = await AnthropicAIGen.getModels(); break; diff --git a/ui/components/options.tsx b/ui/components/options.tsx index 7e2e380..7a9622a 100644 --- a/ui/components/options.tsx +++ b/ui/components/options.tsx @@ -51,7 +51,7 @@ const config = { { "name": "ElevenLabs", "description": "ElevenLabs advanced high-quality TTS (API key required)", - "type": "ElevenLabsTTS", + "type": "ElevenLabs", }, { "name": "Neets.ai",