(
}
event.preventDefault();
-
sendMessage?.(event);
}
}}
diff --git a/app/components/chat/Chat.client.tsx b/app/components/chat/Chat.client.tsx
index 458bd8364..16d07d5e9 100644
--- a/app/components/chat/Chat.client.tsx
+++ b/app/components/chat/Chat.client.tsx
@@ -1,240 +1,320 @@
// @ts-nocheck
// Preventing TS checks with files presented in the video for a better presentation.
-import { useStore } from '@nanostores/react';
-import type { Message } from 'ai';
-import { useChat } from 'ai/react';
-import { useAnimate } from 'framer-motion';
-import { memo, useEffect, useRef, useState } from 'react';
-import { cssTransition, toast, ToastContainer } from 'react-toastify';
-import { useMessageParser, usePromptEnhancer, useShortcuts, useSnapScroll } from '~/lib/hooks';
-import { useChatHistory } from '~/lib/persistence';
-import { chatStore } from '~/lib/stores/chat';
-import { workbenchStore } from '~/lib/stores/workbench';
-import { fileModificationsToHTML } from '~/utils/diff';
-import { DEFAULT_MODEL } from '~/utils/constants';
-import { cubicEasingFn } from '~/utils/easings';
-import { createScopedLogger, renderLogger } from '~/utils/logger';
-import { BaseChat } from './BaseChat';
+import { useStore } from "@nanostores/react";
+import type { Message } from "ai";
+import { useChat } from "ai/react";
+import { useAnimate } from "framer-motion";
+import { memo, useEffect, useRef, useState } from "react";
+import { cssTransition, toast, ToastContainer } from "react-toastify";
+import {
+ useMessageParser,
+ usePromptEnhancer,
+ useShortcuts,
+ useSnapScroll,
+} from "~/lib/hooks";
+import { useChatHistory } from "~/lib/persistence";
+import { chatStore } from "~/lib/stores/chat";
+import { workbenchStore } from "~/lib/stores/workbench";
+import { fileModificationsToHTML } from "~/utils/diff";
+import {
+ DEFAULT_MODEL,
+ DEFAULT_PROVIDER,
+ getModelListInfo,
+ initializeModelList,
+ isInitialized,
+ MODEL_LIST,
+ PROVIDER_LIST,
+} from "~/utils/constants";
+import { cubicEasingFn } from "~/utils/easings";
+import { createScopedLogger, renderLogger } from "~/utils/logger";
+import { BaseChat } from "./BaseChat";
const toastAnimation = cssTransition({
- enter: 'animated fadeInRight',
- exit: 'animated fadeOutRight',
+ enter: "animated fadeInRight",
+ exit: "animated fadeOutRight",
});
-const logger = createScopedLogger('Chat');
+const logger = createScopedLogger("Chat");
export function Chat() {
- renderLogger.trace('Chat');
-
- const { ready, initialMessages, storeMessageHistory } = useChatHistory();
-
- return (
- <>
- {ready &&
}
-
{
- return (
-
- );
- }}
- icon={({ type }) => {
- /**
- * @todo Handle more types if we need them. This may require extra color palettes.
- */
- switch (type) {
- case 'success': {
- return ;
- }
- case 'error': {
- return ;
- }
- }
-
- return undefined;
- }}
- position="bottom-right"
- pauseOnFocusLoss
- transition={toastAnimation}
- />
- >
- );
+ renderLogger.trace("Chat");
+
+ const { ready, initialMessages, storeMessageHistory } = useChatHistory();
+
+ return (
+ <>
+ {ready && (
+
+ )}
+ {
+ return (
+
+ );
+ }}
+ icon={({ type }) => {
+ /**
+ * @todo Handle more types if we need them. This may require extra color palettes.
+ */
+ switch (type) {
+ case "success": {
+ return (
+
+ );
+ }
+ case "error": {
+ return (
+
+ );
+ }
+ }
+
+ return undefined;
+ }}
+ position="bottom-right"
+ pauseOnFocusLoss
+ transition={toastAnimation}
+ />
+ >
+ );
}
interface ChatProps {
- initialMessages: Message[];
- storeMessageHistory: (messages: Message[]) => Promise<void>;
+ initialMessages: Message[];
+ storeMessageHistory: (messages: Message[]) => Promise<void>;
}
-export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProps) => {
- useShortcuts();
-
- const textareaRef = useRef<HTMLTextAreaElement>(null);
-
- const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
- const [model, setModel] = useState(DEFAULT_MODEL);
-
- const { showChat } = useStore(chatStore);
-
- const [animationScope, animate] = useAnimate();
-
- const { messages, isLoading, input, handleInputChange, setInput, stop, append } = useChat({
- api: '/api/chat',
- onError: (error) => {
- logger.error('Request failed\n\n', error);
- toast.error('There was an error processing your request');
- },
- onFinish: () => {
- logger.debug('Finished streaming');
- },
- initialMessages,
- });
-
- const { enhancingPrompt, promptEnhanced, enhancePrompt, resetEnhancer } = usePromptEnhancer();
- const { parsedMessages, parseMessages } = useMessageParser();
-
- const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200;
-
- useEffect(() => {
- chatStore.setKey('started', initialMessages.length > 0);
- }, []);
-
- useEffect(() => {
- parseMessages(messages, isLoading);
-
- if (messages.length > initialMessages.length) {
- storeMessageHistory(messages).catch((error) => toast.error(error.message));
- }
- }, [messages, isLoading, parseMessages]);
-
- const scrollTextArea = () => {
- const textarea = textareaRef.current;
-
- if (textarea) {
- textarea.scrollTop = textarea.scrollHeight;
- }
- };
-
- const abort = () => {
- stop();
- chatStore.setKey('aborted', true);
- workbenchStore.abortAllActions();
- };
-
- useEffect(() => {
- const textarea = textareaRef.current;
-
- if (textarea) {
- textarea.style.height = 'auto';
-
- const scrollHeight = textarea.scrollHeight;
-
- textarea.style.height = `${Math.min(scrollHeight, TEXTAREA_MAX_HEIGHT)}px`;
- textarea.style.overflowY = scrollHeight > TEXTAREA_MAX_HEIGHT ? 'auto' : 'hidden';
- }
- }, [input, textareaRef]);
-
- const runAnimation = async () => {
- if (chatStarted) {
- return;
- }
-
- await Promise.all([
- animate('#examples', { opacity: 0, display: 'none' }, { duration: 0.1 }),
- animate('#intro', { opacity: 0, flex: 1 }, { duration: 0.2, ease: cubicEasingFn }),
- ]);
-
- chatStore.setKey('started', true);
-
- setChatStarted(true);
- };
-
- const sendMessage = async (_event: React.UIEvent, messageInput?: string) => {
- const _input = messageInput || input;
-
- if (_input.length === 0 || isLoading) {
- return;
- }
-
- /**
- * @note (delm) Usually saving files shouldn't take long but it may take longer if there
- * many unsaved files. In that case we need to block user input and show an indicator
- * of some kind so the user is aware that something is happening. But I consider the
- * happy case to be no unsaved files and I would expect users to save their changes
- * before they send another message.
- */
- await workbenchStore.saveAllFiles();
-
- const fileModifications = workbenchStore.getFileModifcations();
-
- chatStore.setKey('aborted', false);
-
- runAnimation();
-
- if (fileModifications !== undefined) {
- const diff = fileModificationsToHTML(fileModifications);
-
- /**
- * If we have file modifications we append a new user message manually since we have to prefix
- * the user input with the file modifications and we don't want the new user input to appear
- * in the prompt. Using `append` is almost the same as `handleSubmit` except that we have to
- * manually reset the input and we'd have to manually pass in file attachments. However, those
- * aren't relevant here.
- */
- append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` });
-
- /**
- * After sending a new message we reset all modifications since the model
- * should now be aware of all the changes.
- */
- workbenchStore.resetAllFileModifications();
- } else {
- append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` });
- }
-
- setInput('');
-
- resetEnhancer();
-
- textareaRef.current?.blur();
- };
-
- const [messageRef, scrollRef] = useSnapScroll();
-
- return (
- {
- if (message.role === 'user') {
- return message;
- }
-
- return {
- ...message,
- content: parsedMessages[i] || '',
- };
- })}
- enhancePrompt={() => {
- enhancePrompt(input, (input) => {
- setInput(input);
- scrollTextArea();
- });
- }}
- />
- );
-});
+export const ChatImpl = memo(
+ ({ initialMessages, storeMessageHistory }: ChatProps) => {
+ useShortcuts();
+
+ const textareaRef = useRef<HTMLTextAreaElement>(null);
+
+ const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
+ const [model, setModel] = useState(DEFAULT_MODEL);
+ const [provider, setProvider] = useState(DEFAULT_PROVIDER);
+ const list = getModelListInfo();
+ const [modelList, setModelList] = useState(list.modelList);
+ const [providerList, setProviderList] = useState(list.providerList);
+ // TODO: Add API key
+ const [api_key, setApiKey] = useState("");
+ const initialize = async () => {
+ if (!isInitialized) {
+ const { modelList, providerList } = await initializeModelList();
+ setModelList(modelList);
+ setProviderList(providerList);
+ }
+ };
+ initialize();
+
+ const { showChat } = useStore(chatStore);
+
+ const [animationScope, animate] = useAnimate();
+
+ const {
+ messages,
+ isLoading,
+ input,
+ handleInputChange,
+ setInput,
+ stop,
+ append,
+ } = useChat({
+ api: "/api/chat",
+ onError: (error) => {
+ logger.error("Request failed\n\n", error);
+ toast.error("There was an error processing your request");
+ },
+ onFinish: () => {
+ logger.debug("Finished streaming");
+ },
+ initialMessages,
+ });
+
+ const { enhancingPrompt, promptEnhanced, enhancePrompt, resetEnhancer } =
+ usePromptEnhancer();
+ const { parsedMessages, parseMessages } = useMessageParser();
+
+ const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200;
+
+ useEffect(() => {
+ chatStore.setKey("started", initialMessages.length > 0);
+ }, []);
+
+ useEffect(() => {
+ parseMessages(messages, isLoading);
+
+ if (messages.length > initialMessages.length) {
+ storeMessageHistory(messages).catch((error) =>
+ toast.error(error.message),
+ );
+ }
+ }, [messages, isLoading, parseMessages]);
+
+ const scrollTextArea = () => {
+ const textarea = textareaRef.current;
+
+ if (textarea) {
+ textarea.scrollTop = textarea.scrollHeight;
+ }
+ };
+
+ const abort = () => {
+ stop();
+ chatStore.setKey("aborted", true);
+ workbenchStore.abortAllActions();
+ };
+
+ useEffect(() => {
+ const textarea = textareaRef.current;
+
+ if (textarea) {
+ textarea.style.height = "auto";
+
+ const scrollHeight = textarea.scrollHeight;
+
+ textarea.style.height = `${Math.min(scrollHeight, TEXTAREA_MAX_HEIGHT)}px`;
+ textarea.style.overflowY =
+ scrollHeight > TEXTAREA_MAX_HEIGHT ? "auto" : "hidden";
+ }
+ }, [input, textareaRef]);
+
+ const runAnimation = async () => {
+ if (chatStarted) {
+ return;
+ }
+
+ await Promise.all([
+ animate(
+ "#examples",
+ { opacity: 0, display: "none" },
+ { duration: 0.1 },
+ ),
+ animate(
+ "#intro",
+ { opacity: 0, flex: 1 },
+ { duration: 0.2, ease: cubicEasingFn },
+ ),
+ ]);
+
+ chatStore.setKey("started", true);
+
+ setChatStarted(true);
+ };
+
+ const sendMessage = async (
+ _event: React.UIEvent,
+ messageInput?: string,
+ ) => {
+ const _input = messageInput || input;
+
+ if (_input.length === 0 || isLoading) {
+ return;
+ }
+
+ /**
+ * @note (delm) Usually saving files shouldn't take long but it may take longer if there
+ * are many unsaved files. In that case we need to block user input and show an indicator
+ * of some kind so the user is aware that something is happening. But I consider the
+ * happy case to be no unsaved files and I would expect users to save their changes
+ * before they send another message.
+ */
+ await workbenchStore.saveAllFiles();
+
+ const fileModifications = workbenchStore.getFileModifcations();
+
+ chatStore.setKey("aborted", false);
+
+ runAnimation();
+ const message = { role: "user", content: "" };
+ const body = {
+ model,
+ provider,
+ api_key,
+ };
+ if (fileModifications !== undefined) {
+ const diff = fileModificationsToHTML(fileModifications);
+
+ /**
+ * If we have file modifications we append a new user message manually since we have to prefix
+ * the user input with the file modifications and we don't want the new user input to appear
+ * in the prompt. Using `append` is almost the same as `handleSubmit` except that we have to
+ * manually reset the input and we'd have to manually pass in file attachments. However, those
+ * aren't relevant here.
+ */
+ message.content = `${diff}\n\n${_input}`;
+
+ /**
+ * After sending a new message we reset all modifications since the model
+ * should now be aware of all the changes.
+ */
+ workbenchStore.resetAllFileModifications();
+ } else {
+ message.content = _input;
+ }
+ append(message, {
+ body,
+ });
+ setInput("");
+
+ resetEnhancer();
+
+ textareaRef.current?.blur();
+ };
+
+ const [messageRef, scrollRef] = useSnapScroll();
+
+ return (
+ {
+ if (message.role === "user") {
+ return message;
+ }
+
+ return {
+ ...message,
+ content: parsedMessages[i] || "",
+ };
+ })}
+ enhancePrompt={() => {
+ enhancePrompt(
+ {
+ input,
+ model,
+ provider,
+ api_key,
+ },
+ (input) => {
+ setInput(input);
+ scrollTextArea();
+ },
+ );
+ }}
+ />
+ );
+ },
+);
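For reference, the per-message request body assembled above has roughly the shape below. This is a minimal sketch derived from the `append(message, { body })` call and the `ChatRequest` type added in `app/routes/api.chat.ts`; it is not code from the patch itself, and `ChatBody` is an illustrative name.

```ts
// Sketch of the extra fields the client now sends with every /api/chat request.
// Field names mirror the ChatRequest type in app/routes/api.chat.ts.
type ChatBody = {
  model: string;    // model selected in the UI; unknown values fall back to DEFAULT_MODEL server-side
  provider: string; // e.g. "Anthropic", "OpenAI", "Ollama", "OpenAILike", ...
  api_key: string;  // user-supplied key; when empty the server falls back to keys from env
};

// The Vercel AI SDK merges `body` into the POST payload, so the server receives
// { messages, model, provider, api_key } in a single JSON object:
// append({ role: "user", content: text }, { body: { model, provider, api_key } });
```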
diff --git a/app/entry.server.tsx b/app/entry.server.tsx
index be2b42bf0..0217634a1 100644
--- a/app/entry.server.tsx
+++ b/app/entry.server.tsx
@@ -5,7 +5,6 @@ import { renderToReadableStream } from 'react-dom/server';
import { renderHeadToString } from 'remix-island';
import { Head } from './root';
import { themeStore } from '~/lib/stores/theme';
-import { initializeModelList } from '~/utils/constants';
export default async function handleRequest(
request: Request,
@@ -14,7 +13,7 @@ export default async function handleRequest(
remixContext: EntryContext,
_loadContext: AppLoadContext,
) {
- await initializeModelList();
+
const readable = await renderToReadableStream(<RemixServer context={remixContext} url={request.url} />, {
signal: request.signal,
diff --git a/app/lib/.server/llm/model.ts b/app/lib/.server/llm/model.ts
index 390d57aeb..8338ab387 100644
--- a/app/lib/.server/llm/model.ts
+++ b/app/lib/.server/llm/model.ts
@@ -6,7 +6,6 @@ import { createOpenAI } from '@ai-sdk/openai';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { ollama } from 'ollama-ai-provider';
import { createOpenRouter } from "@openrouter/ai-sdk-provider";
-import { mistral } from '@ai-sdk/mistral';
import { createMistral } from '@ai-sdk/mistral';
export function getAnthropicModel(apiKey: string, model: string) {
@@ -80,27 +79,27 @@ export function getOpenRouterModel(apiKey: string, model: string) {
return openRouter.chat(model);
}
-export function getModel(provider: string, model: string, env: Env) {
- const apiKey = getAPIKey(env, provider);
+export function getModel(provider: string, model: string, apiKey: string, env: Env) {
+ const _apiKey = apiKey || getAPIKey(env, provider);
const baseURL = getBaseURL(env, provider);
switch (provider) {
case 'Anthropic':
- return getAnthropicModel(apiKey, model);
+ return getAnthropicModel(_apiKey, model);
case 'OpenAI':
- return getOpenAIModel(apiKey, model);
+ return getOpenAIModel(_apiKey, model);
case 'Groq':
- return getGroqModel(apiKey, model);
+ return getGroqModel(_apiKey, model);
case 'OpenRouter':
- return getOpenRouterModel(apiKey, model);
+ return getOpenRouterModel(_apiKey, model);
case 'Google':
- return getGoogleModel(apiKey, model)
+ return getGoogleModel(_apiKey, model)
case 'OpenAILike':
- return getOpenAILikeModel(baseURL,apiKey, model);
+ return getOpenAILikeModel(baseURL, _apiKey, model);
case 'Deepseek':
- return getDeepseekModel(apiKey, model)
+ return getDeepseekModel(_apiKey, model)
case 'Mistral':
- return getMistralModel(apiKey, model);
+ return getMistralModel(_apiKey, model);
default:
return getOllamaModel(baseURL, model);
}
diff --git a/app/lib/.server/llm/stream-text.ts b/app/lib/.server/llm/stream-text.ts
index de3d5bfa8..f1e057f41 100644
--- a/app/lib/.server/llm/stream-text.ts
+++ b/app/lib/.server/llm/stream-text.ts
@@ -1,66 +1,64 @@
// @ts-nocheck
// Preventing TS checks with files presented in the video for a better presentation.
-import { streamText as _streamText, convertToCoreMessages } from 'ai';
-import { getModel } from '~/lib/.server/llm/model';
-import { MAX_TOKENS } from './constants';
-import { getSystemPrompt } from './prompts';
-import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants';
+import { streamText as _streamText, convertToCoreMessages } from "ai";
+import { getModel } from "~/lib/.server/llm/model";
+import { MAX_TOKENS } from "./constants";
+import { getSystemPrompt } from "./prompts";
+import { DEFAULT_MODEL, DEFAULT_PROVIDER, hasModel } from "~/utils/constants";
+import type { ChatRequest } from "~/routes/api.chat";
interface ToolResult<Name, Args, Result> {
- toolCallId: string;
- toolName: Name;
- args: Args;
- result: Result;
+ toolCallId: string;
+ toolName: Name;
+ args: Args;
+ result: Result;
}
interface Message {
- role: 'user' | 'assistant';
- content: string;
- toolInvocations?: ToolResult[];
- model?: string;
+ role: "user" | "assistant";
+ content: string;
+ toolInvocations?: ToolResult[];
+ model?: string;
}
export type Messages = Message[];
-export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;
+export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], "model">;
-function extractModelFromMessage(message: Message): { model: string; content: string } {
- const modelRegex = /^\[Model: (.*?)\]\n\n/;
- const match = message.content.match(modelRegex);
+// function extractModelFromMessage(message: Message): { model: string; content: string } {
+// const modelRegex = /^\[Model: (.*?)Provider: (.*?)\]\n\n/;
+// const match = message.content.match(modelRegex);
+//
+// if (!match) {
+// return { model: DEFAULT_MODEL, content: message.content,provider: DEFAULT_PROVIDER };
+// }
+// const [_,model,provider] = match;
+// const content = message.content.replace(modelRegex, '');
+// return { model, content ,provider};
+// // Default model if not specified
+//
+// }
- if (match) {
- const model = match[1];
- const content = message.content.replace(modelRegex, '');
- return { model, content };
- }
+export function streamText(
+ chatRequest: ChatRequest,
+ env: Env,
+ options?: StreamingOptions,
+) {
+ const { messages, model, api_key, provider } = chatRequest;
- // Default model if not specified
- return { model: DEFAULT_MODEL, content: message.content };
-}
-
-export function streamText(messages: Messages, env: Env, options?: StreamingOptions) {
- let currentModel = DEFAULT_MODEL;
- const processedMessages = messages.map((message) => {
- if (message.role === 'user') {
- const { model, content } = extractModelFromMessage(message);
- if (model && MODEL_LIST.find((m) => m.name === model)) {
- currentModel = model; // Update the current model
- }
- return { ...message, content };
- }
- return message;
- });
-
- const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER;
+ const _hasModel = hasModel(model, provider);
+ let currentModel = _hasModel ? model : DEFAULT_MODEL;
+ let currentProvider = _hasModel ? provider : DEFAULT_PROVIDER;
- return _streamText({
- model: getModel(provider, currentModel, env),
- system: getSystemPrompt(),
- maxTokens: MAX_TOKENS,
- // headers: {
- // 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
- // },
- messages: convertToCoreMessages(processedMessages),
- ...options,
- });
+ const coreMessages = convertToCoreMessages(messages);
+ return _streamText({
+ model: getModel(currentProvider, currentModel, api_key, env),
+ system: getSystemPrompt(),
+ maxTokens: MAX_TOKENS,
+ // headers: {
+ // 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
+ // },
+ messages: coreMessages,
+ ...options,
+ });
}
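The model/provider fallback in `streamText` above can be read in isolation as roughly the following sketch; `resolveModel` is an illustrative helper, not a function in the patch.

```ts
import { DEFAULT_MODEL, DEFAULT_PROVIDER, hasModel } from "~/utils/constants";

// Illustrative helper: if the requested model/provider pair is not in MODEL_LIST,
// fall back to the defaults, matching the hasModel() check in streamText above.
function resolveModel(model: string, provider: string) {
  const known = hasModel(model, provider);
  return {
    model: known ? model : DEFAULT_MODEL,
    provider: known ? provider : DEFAULT_PROVIDER,
  };
}
```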
diff --git a/app/lib/hooks/usePromptEnhancer.ts b/app/lib/hooks/usePromptEnhancer.ts
index f376cc0cd..780989a84 100644
--- a/app/lib/hooks/usePromptEnhancer.ts
+++ b/app/lib/hooks/usePromptEnhancer.ts
@@ -1,71 +1,83 @@
-import { useState } from 'react';
-import { createScopedLogger } from '~/utils/logger';
+import { useState } from "react";
+import { createScopedLogger } from "~/utils/logger";
-const logger = createScopedLogger('usePromptEnhancement');
+const logger = createScopedLogger("usePromptEnhancement");
export function usePromptEnhancer() {
- const [enhancingPrompt, setEnhancingPrompt] = useState(false);
- const [promptEnhanced, setPromptEnhanced] = useState(false);
-
- const resetEnhancer = () => {
- setEnhancingPrompt(false);
- setPromptEnhanced(false);
- };
-
- const enhancePrompt = async (input: string, setInput: (value: string) => void) => {
- setEnhancingPrompt(true);
- setPromptEnhanced(false);
-
- const response = await fetch('/api/enhancer', {
- method: 'POST',
- body: JSON.stringify({
- message: input,
- }),
- });
-
- const reader = response.body?.getReader();
-
- const originalInput = input;
-
- if (reader) {
- const decoder = new TextDecoder();
-
- let _input = '';
- let _error;
-
- try {
- setInput('');
-
- while (true) {
- const { value, done } = await reader.read();
-
- if (done) {
- break;
- }
-
- _input += decoder.decode(value);
-
- logger.trace('Set input', _input);
-
- setInput(_input);
- }
- } catch (error) {
- _error = error;
- setInput(originalInput);
- } finally {
- if (_error) {
- logger.error(_error);
- }
-
- setEnhancingPrompt(false);
- setPromptEnhanced(true);
-
- setTimeout(() => {
- setInput(_input);
- });
- }
- }
- };
-
- return { enhancingPrompt, promptEnhanced, enhancePrompt, resetEnhancer };
+ const [enhancingPrompt, setEnhancingPrompt] = useState(false);
+ const [promptEnhanced, setPromptEnhanced] = useState(false);
+
+ const resetEnhancer = () => {
+ setEnhancingPrompt(false);
+ setPromptEnhanced(false);
+ };
+ type EnhancePrompt = {
+ input: string;
+ model: string;
+ provider: string;
+ api_key: string;
+ };
+
+ const enhancePrompt = async (
+ { input, model, provider, api_key }: EnhancePrompt,
+ setInput: (value: string) => void,
+ ) => {
+ setEnhancingPrompt(true);
+ setPromptEnhanced(false);
+
+ const response = await fetch("/api/enhancer", {
+ method: "POST",
+ body: JSON.stringify({
+ message: input,
+ model,
+ provider,
+ api_key,
+ }),
+ });
+
+ const reader = response.body?.getReader();
+
+ const originalInput = input;
+
+ if (reader) {
+ const decoder = new TextDecoder();
+
+ let _input = "";
+ let _error;
+
+ try {
+ setInput("");
+
+ while (true) {
+ const { value, done } = await reader.read();
+
+ if (done) {
+ break;
+ }
+
+ _input += decoder.decode(value);
+
+ logger.trace("Set input", _input);
+
+ setInput(_input);
+ }
+ } catch (error) {
+ _error = error;
+ setInput(originalInput);
+ } finally {
+ if (_error) {
+ logger.error(_error);
+ }
+
+ setEnhancingPrompt(false);
+ setPromptEnhanced(true);
+
+ setTimeout(() => {
+ setInput(_input);
+ });
+ }
+ }
+ };
+
+ return { enhancingPrompt, promptEnhanced, enhancePrompt, resetEnhancer };
}
diff --git a/app/routes/_index.tsx b/app/routes/_index.tsx
index 86d73409c..7f197908a 100644
--- a/app/routes/_index.tsx
+++ b/app/routes/_index.tsx
@@ -3,6 +3,8 @@ import { ClientOnly } from 'remix-utils/client-only';
import { BaseChat } from '~/components/chat/BaseChat';
import { Chat } from '~/components/chat/Chat.client';
import { Header } from '~/components/header/Header';
+import { useState } from 'react';
+import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST } from '~/utils/constants';
export const meta: MetaFunction = () => {
return [{ title: 'Bolt' }, { name: 'description', content: 'Talk with Bolt, an AI assistant from StackBlitz' }];
@@ -11,10 +13,14 @@ export const meta: MetaFunction = () => {
export const loader = () => json({});
export default function Index() {
+ const [model, setModel] = useState(DEFAULT_MODEL);
+ const [provider, setProvider] = useState(DEFAULT_PROVIDER);
+ const [modelList, setModelList] = useState(MODEL_LIST);
+ const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]);
return (
- }>{() => }
+ }>{() => }
);
}
diff --git a/app/routes/api.chat.ts b/app/routes/api.chat.ts
index c455c193d..69db9a912 100644
--- a/app/routes/api.chat.ts
+++ b/app/routes/api.chat.ts
@@ -1,61 +1,79 @@
// @ts-nocheck
// Preventing TS checks with files presented in the video for a better presentation.
-import { type ActionFunctionArgs } from '@remix-run/cloudflare';
-import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
-import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts';
-import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text';
-import SwitchableStream from '~/lib/.server/llm/switchable-stream';
+import { type ActionFunctionArgs } from "@remix-run/cloudflare";
+import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from "~/lib/.server/llm/constants";
+import { CONTINUE_PROMPT } from "~/lib/.server/llm/prompts";
+import {
+ streamText,
+ type Messages,
+ type StreamingOptions,
+} from "~/lib/.server/llm/stream-text";
+import SwitchableStream from "~/lib/.server/llm/switchable-stream";
export async function action(args: ActionFunctionArgs) {
- return chatAction(args);
+ return chatAction(args);
}
+export type ChatRequest = {
+ messages: Messages;
+ model: string;
+ provider: string;
+ api_key: string;
+};
async function chatAction({ context, request }: ActionFunctionArgs) {
- const { messages } = await request.json<{ messages: Messages }>();
+ const chatRequest = await request.json<ChatRequest>();
+ const stream = new SwitchableStream();
+ try {
+ const options: StreamingOptions = {
+ toolChoice: "none",
+ onFinish: async ({ text: content, finishReason }) => {
+ if (finishReason !== "length") {
+ return stream.close();
+ }
- const stream = new SwitchableStream();
+ if (stream.switches >= MAX_RESPONSE_SEGMENTS) {
+ throw Error("Cannot continue message: Maximum segments reached");
+ }
- try {
- const options: StreamingOptions = {
- toolChoice: 'none',
- onFinish: async ({ text: content, finishReason }) => {
- if (finishReason !== 'length') {
- return stream.close();
- }
+ const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches;
- if (stream.switches >= MAX_RESPONSE_SEGMENTS) {
- throw Error('Cannot continue message: Maximum segments reached');
- }
+ console.log(
+ `Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`,
+ );
- const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches;
+ chatRequest.messages.push({ role: "assistant", content });
+ chatRequest.messages.push({ role: "user", content: CONTINUE_PROMPT });
- console.log(`Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`);
+ const result = await streamText(
+ chatRequest,
+ context.cloudflare.env,
+ options,
+ );
- messages.push({ role: 'assistant', content });
- messages.push({ role: 'user', content: CONTINUE_PROMPT });
+ return stream.switchSource(result.toAIStream());
+ },
+ };
- const result = await streamText(messages, context.cloudflare.env, options);
+ const result = await streamText(
+ chatRequest,
+ context.cloudflare.env,
+ options,
+ );
- return stream.switchSource(result.toAIStream());
- },
- };
+ stream.switchSource(result.toAIStream());
- const result = await streamText(messages, context.cloudflare.env, options);
+ return new Response(stream.readable, {
+ status: 200,
+ headers: {
+ contentType: "text/plain; charset=utf-8",
+ },
+ });
+ } catch (error) {
+ console.log(error);
- stream.switchSource(result.toAIStream());
-
- return new Response(stream.readable, {
- status: 200,
- headers: {
- contentType: 'text/plain; charset=utf-8',
- },
- });
- } catch (error) {
- console.log(error);
-
- throw new Response(null, {
- status: 500,
- statusText: 'Internal Server Error',
- });
- }
+ throw new Response(null, {
+ status: 500,
+ statusText: "Internal Server Error",
+ });
+ }
}
diff --git a/app/routes/api.enhancer.ts b/app/routes/api.enhancer.ts
index 5c8175ca3..d555aac98 100644
--- a/app/routes/api.enhancer.ts
+++ b/app/routes/api.enhancer.ts
@@ -1,24 +1,32 @@
-import { type ActionFunctionArgs } from '@remix-run/cloudflare';
-import { StreamingTextResponse, parseStreamPart } from 'ai';
-import { streamText } from '~/lib/.server/llm/stream-text';
-import { stripIndents } from '~/utils/stripIndent';
+import { type ActionFunctionArgs } from "@remix-run/cloudflare";
+import { StreamingTextResponse, parseStreamPart } from "ai";
+import { type Messages, streamText } from "~/lib/.server/llm/stream-text";
+import { stripIndents } from "~/utils/stripIndent";
+import type { ChatRequest } from "~/routes/api.chat";
const encoder = new TextEncoder();
const decoder = new TextDecoder();
export async function action(args: ActionFunctionArgs) {
- return enhancerAction(args);
+ return enhancerAction(args);
}
-
+type EnhancePrompt = {
+ message: string;
+ model: string;
+ provider: string;
+ api_key: string;
+};
async function enhancerAction({ context, request }: ActionFunctionArgs) {
- const { message } = await request.json<{ message: string }>();
-
- try {
- const result = await streamText(
- [
- {
- role: 'user',
- content: stripIndents`
+ const { message, model, provider, api_key } =
+ await request.json<EnhancePrompt>();
+ const input: ChatRequest = {
+ model,
+ provider,
+ api_key,
+ messages: [
+ {
+ role: "user",
+ content: stripIndents`
I want you to improve the user prompt that is wrapped in \`\` tags.
IMPORTANT: Only respond with the improved prompt and nothing else!
@@ -27,34 +35,35 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) {
${message}
`,
- },
- ],
- context.cloudflare.env,
- );
-
- const transformStream = new TransformStream({
- transform(chunk, controller) {
- const processedChunk = decoder
- .decode(chunk)
- .split('\n')
- .filter((line) => line !== '')
- .map(parseStreamPart)
- .map((part) => part.value)
- .join('');
-
- controller.enqueue(encoder.encode(processedChunk));
- },
- });
-
- const transformedStream = result.toAIStream().pipeThrough(transformStream);
-
- return new StreamingTextResponse(transformedStream);
- } catch (error) {
- console.log(error);
-
- throw new Response(null, {
- status: 500,
- statusText: 'Internal Server Error',
- });
- }
+ },
+ ],
+ };
+ try {
+ const result = await streamText(input, context.cloudflare.env);
+
+ const transformStream = new TransformStream({
+ transform(chunk, controller) {
+ const processedChunk = decoder
+ .decode(chunk)
+ .split("\n")
+ .filter((line) => line !== "")
+ .map(parseStreamPart)
+ .map((part) => part.value)
+ .join("");
+
+ controller.enqueue(encoder.encode(processedChunk));
+ },
+ });
+
+ const transformedStream = result.toAIStream().pipeThrough(transformStream);
+
+ return new StreamingTextResponse(transformedStream);
+ } catch (error) {
+ console.log(error);
+
+ throw new Response(null, {
+ status: 500,
+ statusText: "Internal Server Error",
+ });
+ }
}
diff --git a/app/routes/api.models.ts b/app/routes/api.models.ts
index ace4ef009..4fdff3812 100644
--- a/app/routes/api.models.ts
+++ b/app/routes/api.models.ts
@@ -1,6 +1,8 @@
-import { json } from '@remix-run/cloudflare';
-import { MODEL_LIST } from '~/utils/constants';
+import { json, type LoaderFunctionArgs } from '@remix-run/cloudflare';
+import { initializeModelList } from '~/utils/tools';
-export async function loader() {
- return json(MODEL_LIST);
+export async function loader({ context }: LoaderFunctionArgs) {
+ const { env } = context.cloudflare;
+ const modelList = await initializeModelList(env);
+ return json(modelList);
}
diff --git a/app/utils/constants.ts b/app/utils/constants.ts
index b48cb3442..d3a49b01d 100644
--- a/app/utils/constants.ts
+++ b/app/utils/constants.ts
@@ -1,5 +1,4 @@
-import type { ModelInfo, OllamaApiResponse, OllamaModel } from './types';
-
+import type { ModelInfo } from './types';
export const WORK_DIR_NAME = 'project';
export const WORK_DIR = `/home/${WORK_DIR_NAME}`;
export const MODIFICATIONS_TAG_NAME = 'bolt_file_modifications';
@@ -46,50 +45,52 @@ const staticModels: ModelInfo[] = [
];
export let MODEL_LIST: ModelInfo[] = [...staticModels];
+export const PROVIDER_LIST: string[] = ['Ollama', 'OpenAILike']
-async function getOllamaModels(): Promise<ModelInfo[]> {
- try {
- const base_url = import.meta.env.OLLAMA_API_BASE_URL || "http://localhost:11434";
- const response = await fetch(`${base_url}/api/tags`);
- const data = await response.json() as OllamaApiResponse;
-
- return data.models.map((model: OllamaModel) => ({
- name: model.name,
- label: `${model.name} (${model.details.parameter_size})`,
- provider: 'Ollama',
- }));
- } catch (e) {
- return [];
+export function hasModel(modelName: string, provider: string): boolean {
+ for (const model of MODEL_LIST) {
+ if ( model.provider === provider && model.name === modelName) {
+ return true;
+ }
}
+ return false
}
+export const IS_SERVER = typeof window === 'undefined';
-async function getOpenAILikeModels(): Promise<ModelInfo[]> {
- try {
- const base_url =import.meta.env.OPENAI_LIKE_API_BASE_URL || "";
- if (!base_url) {
- return [];
- }
- const api_key = import.meta.env.OPENAI_LIKE_API_KEY ?? "";
- const response = await fetch(`${base_url}/models`, {
- headers: {
- Authorization: `Bearer ${api_key}`,
- }
- });
- const res = await response.json() as any;
- return res.data.map((model: any) => ({
- name: model.id,
- label: model.id,
- provider: 'OpenAILike',
- }));
- }catch (e) {
- return []
- }
+export function setModelList(models: ModelInfo[]): void {
+ MODEL_LIST = models;
+}
+export function getStaticModels(): ModelInfo[] {
+ return [...staticModels];
}
-async function initializeModelList(): Promise<void> {
- const ollamaModels = await getOllamaModels();
- const openAiLikeModels = await getOpenAILikeModels();
- MODEL_LIST = [...ollamaModels,...openAiLikeModels, ...staticModels];
+export let isInitialized = false;
+
+export type ModelListInfo = {
+ modelList: ModelInfo[],
+ providerList: string[]
+}
+
+export async function initializeModelList(): Promise<ModelListInfo> {
+ if (isInitialized) {
+ return getModelListInfo();
+ }
+
+ if (IS_SERVER) {
+ isInitialized = true;
+ return getModelListInfo();
+ }
+
+ isInitialized = true;
+ const response = await fetch('/api/models');
+ MODEL_LIST = (await response.json()) as ModelInfo[];
+
+ return getModelListInfo();
+}
+
+export function getModelListInfo(): ModelListInfo {
+ return {
+ modelList: MODEL_LIST,
+ providerList: [...new Set([...MODEL_LIST.map(m => m.provider), ...PROVIDER_LIST])]
+ };
}
-initializeModelList().then();
-export { getOllamaModels, getOpenAILikeModels, initializeModelList };
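On the client, the lazy initialization added above is consumed roughly as sketched below (this mirrors the pattern used in `Chat.client.tsx`; `ensureModelList` is an illustrative name, not part of the patch).

```ts
import { getModelListInfo, initializeModelList, isInitialized } from "~/utils/constants";

// Illustrative: return the cached list when available, otherwise fetch /api/models once.
async function ensureModelList() {
  if (isInitialized) {
    return getModelListInfo(); // { modelList, providerList }
  }
  return initializeModelList(); // sets isInitialized and caches MODEL_LIST
}
```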
diff --git a/app/utils/tools.ts b/app/utils/tools.ts
new file mode 100644
index 000000000..74cbdadb4
--- /dev/null
+++ b/app/utils/tools.ts
@@ -0,0 +1,74 @@
+import type { ModelInfo, OllamaApiResponse, OllamaModel } from '~/utils/types';
+import { getStaticModels, setModelList } from '~/utils/constants';
+import { getAPIKey, getBaseURL } from '~/lib/.server/llm/api-key';
+
+
+export let MODEL_LIST: ModelInfo[] = [...getStaticModels()];
+
+
+
+let isInitialized = false;
+async function getOllamaModels(env: Env): Promise<ModelInfo[]> {
+ try {
+ const base_url = getBaseURL(env,"Ollama") ;
+ const response = await fetch(`${base_url}/api/tags`);
+ const data = await response.json() as OllamaApiResponse;
+ return data.models.map((model: OllamaModel) => ({
+ name: model.name,
+ label: `${model.name} (${model.details.parameter_size})`,
+ provider: 'Ollama',
+ }));
+ } catch (e) {
+ return [{
+ name: "Empty",
+ label: "Empty",
+ provider: "Ollama"
+ }];
+ }
+}
+
+async function getOpenAILikeModels(env: Env): Promise<ModelInfo[]> {
+ try {
+ const base_url = getBaseURL(env,"OpenAILike") ;
+ if (!base_url) {
+ return [{
+ name: "Empty",
+ label: "Empty",
+ provider: "OpenAILike"
+ }];
+ }
+ const api_key = getAPIKey(env, "OpenAILike") ?? "";
+ const response = await fetch(`${base_url}/models`, {
+ headers: {
+ Authorization: `Bearer ${api_key}`,
+ }
+ });
+ const res = await response.json() as any;
+ return res.data.map((model: any) => ({
+ name: model.id,
+ label: model.id,
+ provider: 'OpenAILike',
+ }));
+ } catch (e) {
+ return [{
+ name: "Empty",
+ label: "Empty",
+ provider: "OpenAILike"
+ }];
+ }
+}
+
+
+async function initializeModelList(env: Env): Promise<ModelInfo[]> {
+ if (isInitialized) {
+ return MODEL_LIST;
+ }
+ isInitialized = true;
+ const ollamaModels = await getOllamaModels(env);
+ const openAiLikeModels = await getOpenAILikeModels(env);
+ MODEL_LIST = [...getStaticModels(), ...ollamaModels, ...openAiLikeModels];
+ setModelList(MODEL_LIST);
+ return MODEL_LIST;
+}
+
+export { getOllamaModels, getOpenAILikeModels, initializeModelList };
diff --git a/package.json b/package.json
index edb2b8dad..1f9f0e845 100644
--- a/package.json
+++ b/package.json
@@ -6,13 +6,14 @@
"sideEffects": false,
"type": "module",
"scripts": {
- "deploy": "npm run build && wrangler pages deploy",
+ "deploy": "pnpm run build && wrangler pages deploy",
+ "dev:deploy": "pnpm run build && wrangler pages dev",
"build": "remix vite:build",
"dev": "remix vite:dev",
"test": "vitest --run",
"test:watch": "vitest",
"lint": "eslint --cache --cache-location ./node_modules/.cache/eslint .",
- "lint:fix": "npm run lint -- --fix",
+ "lint:fix": "pnpm run lint -- --fix",
"start": "bindings=$(./bindings.sh) && wrangler pages dev ./build/client $bindings --ip 0.0.0.0 --port 3000",
"typecheck": "tsc",
"typegen": "wrangler types",