diff --git a/app/components/chat/BaseChat.tsx b/app/components/chat/BaseChat.tsx index c1175f700..9d8fa87c4 100644 --- a/app/components/chat/BaseChat.tsx +++ b/app/components/chat/BaseChat.tsx @@ -7,12 +7,11 @@ import { Menu } from '~/components/sidebar/Menu.client'; import { IconButton } from '~/components/ui/IconButton'; import { Workbench } from '~/components/workbench/Workbench.client'; import { classNames } from '~/utils/classNames'; -import { MODEL_LIST, DEFAULT_PROVIDER } from '~/utils/constants'; import { Messages } from './Messages.client'; import { SendButton } from './SendButton.client'; -import { useState } from 'react'; import styles from './BaseChat.module.scss'; +import type { ModelInfo } from '~/utils/types'; const EXAMPLE_PROMPTS = [ { text: 'Build a todo app in React using Tailwind' }, @@ -22,36 +21,33 @@ const EXAMPLE_PROMPTS = [ { text: 'How do I center a div?' }, ]; -const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))] -const ModelSelector = ({ model, setModel, modelList, providerList }) => { - const [provider, setProvider] = useState(DEFAULT_PROVIDER); + +function ModelSelector({ model, setModel ,provider,setProvider,modelList,providerList}) { + + + + return (
); -}; +} const TEXTAREA_MIN_HEIGHT = 76; @@ -77,8 +73,12 @@ interface BaseChatProps { enhancingPrompt?: boolean; promptEnhanced?: boolean; input?: string; - model: string; - setModel: (model: string) => void; + model?: string; + setModel?: (model: string) => void; + provider?: string; + setProvider?: (provider: string) => void; + modelList?: ModelInfo[]; + providerList?: string[]; handleStop?: () => void; sendMessage?: (event: React.UIEvent, messageInput?: string) => void; handleInputChange?: (event: React.ChangeEvent) => void; @@ -100,6 +100,10 @@ export const BaseChat = React.forwardRef( input = '', model, setModel, + provider, + setProvider, + modelList, + providerList, sendMessage, handleInputChange, enhancePrompt, @@ -108,7 +112,6 @@ export const BaseChat = React.forwardRef( ref, ) => { const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200; - return (
(
( } event.preventDefault(); - sendMessage?.(event); } }} diff --git a/app/components/chat/Chat.client.tsx b/app/components/chat/Chat.client.tsx index 458bd8364..16d07d5e9 100644 --- a/app/components/chat/Chat.client.tsx +++ b/app/components/chat/Chat.client.tsx @@ -1,240 +1,320 @@ // @ts-nocheck // Preventing TS checks with files presented in the video for a better presentation. -import { useStore } from '@nanostores/react'; -import type { Message } from 'ai'; -import { useChat } from 'ai/react'; -import { useAnimate } from 'framer-motion'; -import { memo, useEffect, useRef, useState } from 'react'; -import { cssTransition, toast, ToastContainer } from 'react-toastify'; -import { useMessageParser, usePromptEnhancer, useShortcuts, useSnapScroll } from '~/lib/hooks'; -import { useChatHistory } from '~/lib/persistence'; -import { chatStore } from '~/lib/stores/chat'; -import { workbenchStore } from '~/lib/stores/workbench'; -import { fileModificationsToHTML } from '~/utils/diff'; -import { DEFAULT_MODEL } from '~/utils/constants'; -import { cubicEasingFn } from '~/utils/easings'; -import { createScopedLogger, renderLogger } from '~/utils/logger'; -import { BaseChat } from './BaseChat'; +import { useStore } from "@nanostores/react"; +import type { Message } from "ai"; +import { useChat } from "ai/react"; +import { useAnimate } from "framer-motion"; +import { memo, useEffect, useRef, useState } from "react"; +import { cssTransition, toast, ToastContainer } from "react-toastify"; +import { + useMessageParser, + usePromptEnhancer, + useShortcuts, + useSnapScroll, +} from "~/lib/hooks"; +import { useChatHistory } from "~/lib/persistence"; +import { chatStore } from "~/lib/stores/chat"; +import { workbenchStore } from "~/lib/stores/workbench"; +import { fileModificationsToHTML } from "~/utils/diff"; +import { + DEFAULT_MODEL, + DEFAULT_PROVIDER, + getModelListInfo, + initializeModelList, + isInitialized, + MODEL_LIST, + PROVIDER_LIST, +} from "~/utils/constants"; +import { cubicEasingFn } from "~/utils/easings"; +import { createScopedLogger, renderLogger } from "~/utils/logger"; +import { BaseChat } from "./BaseChat"; const toastAnimation = cssTransition({ - enter: 'animated fadeInRight', - exit: 'animated fadeOutRight', + enter: "animated fadeInRight", + exit: "animated fadeOutRight", }); -const logger = createScopedLogger('Chat'); +const logger = createScopedLogger("Chat"); export function Chat() { - renderLogger.trace('Chat'); - - const { ready, initialMessages, storeMessageHistory } = useChatHistory(); - - return ( - <> - {ready && } - { - return ( - - ); - }} - icon={({ type }) => { - /** - * @todo Handle more types if we need them. This may require extra color palettes. - */ - switch (type) { - case 'success': { - return
; - } - case 'error': { - return
; - } - } - - return undefined; - }} - position="bottom-right" - pauseOnFocusLoss - transition={toastAnimation} - /> - - ); + renderLogger.trace("Chat"); + + const { ready, initialMessages, storeMessageHistory } = useChatHistory(); + + return ( + <> + {ready && ( + + )} + { + return ( + + ); + }} + icon={({ type }) => { + /** + * @todo Handle more types if we need them. This may require extra color palettes. + */ + switch (type) { + case "success": { + return ( +
+ ); + } + case "error": { + return ( +
+ ); + } + } + + return undefined; + }} + position="bottom-right" + pauseOnFocusLoss + transition={toastAnimation} + /> + + ); } interface ChatProps { - initialMessages: Message[]; - storeMessageHistory: (messages: Message[]) => Promise; + initialMessages: Message[]; + storeMessageHistory: (messages: Message[]) => Promise; } -export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProps) => { - useShortcuts(); - - const textareaRef = useRef(null); - - const [chatStarted, setChatStarted] = useState(initialMessages.length > 0); - const [model, setModel] = useState(DEFAULT_MODEL); - - const { showChat } = useStore(chatStore); - - const [animationScope, animate] = useAnimate(); - - const { messages, isLoading, input, handleInputChange, setInput, stop, append } = useChat({ - api: '/api/chat', - onError: (error) => { - logger.error('Request failed\n\n', error); - toast.error('There was an error processing your request'); - }, - onFinish: () => { - logger.debug('Finished streaming'); - }, - initialMessages, - }); - - const { enhancingPrompt, promptEnhanced, enhancePrompt, resetEnhancer } = usePromptEnhancer(); - const { parsedMessages, parseMessages } = useMessageParser(); - - const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200; - - useEffect(() => { - chatStore.setKey('started', initialMessages.length > 0); - }, []); - - useEffect(() => { - parseMessages(messages, isLoading); - - if (messages.length > initialMessages.length) { - storeMessageHistory(messages).catch((error) => toast.error(error.message)); - } - }, [messages, isLoading, parseMessages]); - - const scrollTextArea = () => { - const textarea = textareaRef.current; - - if (textarea) { - textarea.scrollTop = textarea.scrollHeight; - } - }; - - const abort = () => { - stop(); - chatStore.setKey('aborted', true); - workbenchStore.abortAllActions(); - }; - - useEffect(() => { - const textarea = textareaRef.current; - - if (textarea) { - textarea.style.height = 'auto'; - - const scrollHeight = textarea.scrollHeight; - - textarea.style.height = `${Math.min(scrollHeight, TEXTAREA_MAX_HEIGHT)}px`; - textarea.style.overflowY = scrollHeight > TEXTAREA_MAX_HEIGHT ? 'auto' : 'hidden'; - } - }, [input, textareaRef]); - - const runAnimation = async () => { - if (chatStarted) { - return; - } - - await Promise.all([ - animate('#examples', { opacity: 0, display: 'none' }, { duration: 0.1 }), - animate('#intro', { opacity: 0, flex: 1 }, { duration: 0.2, ease: cubicEasingFn }), - ]); - - chatStore.setKey('started', true); - - setChatStarted(true); - }; - - const sendMessage = async (_event: React.UIEvent, messageInput?: string) => { - const _input = messageInput || input; - - if (_input.length === 0 || isLoading) { - return; - } - - /** - * @note (delm) Usually saving files shouldn't take long but it may take longer if there - * many unsaved files. In that case we need to block user input and show an indicator - * of some kind so the user is aware that something is happening. But I consider the - * happy case to be no unsaved files and I would expect users to save their changes - * before they send another message. 
- */ - await workbenchStore.saveAllFiles(); - - const fileModifications = workbenchStore.getFileModifcations(); - - chatStore.setKey('aborted', false); - - runAnimation(); - - if (fileModifications !== undefined) { - const diff = fileModificationsToHTML(fileModifications); - - /** - * If we have file modifications we append a new user message manually since we have to prefix - * the user input with the file modifications and we don't want the new user input to appear - * in the prompt. Using `append` is almost the same as `handleSubmit` except that we have to - * manually reset the input and we'd have to manually pass in file attachments. However, those - * aren't relevant here. - */ - append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` }); - - /** - * After sending a new message we reset all modifications since the model - * should now be aware of all the changes. - */ - workbenchStore.resetAllFileModifications(); - } else { - append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` }); - } - - setInput(''); - - resetEnhancer(); - - textareaRef.current?.blur(); - }; - - const [messageRef, scrollRef] = useSnapScroll(); - - return ( - { - if (message.role === 'user') { - return message; - } - - return { - ...message, - content: parsedMessages[i] || '', - }; - })} - enhancePrompt={() => { - enhancePrompt(input, (input) => { - setInput(input); - scrollTextArea(); - }); - }} - /> - ); -}); +export const ChatImpl = memo( + ({ initialMessages, storeMessageHistory }: ChatProps) => { + useShortcuts(); + + const textareaRef = useRef(null); + + const [chatStarted, setChatStarted] = useState(initialMessages.length > 0); + const [model, setModel] = useState(DEFAULT_MODEL); + const [provider, setProvider] = useState(DEFAULT_PROVIDER); + const list = getModelListInfo(); + const [modelList, setModelList] = useState(list.modelList); + const [providerList, setProviderList] = useState(list.providerList); + // TODO: Add API key + const [api_key, setApiKey] = useState(""); + const initialize = async () => { + if (!isInitialized) { + const { modelList, providerList } = await initializeModelList(); + setModelList(modelList); + setProviderList(providerList); + } + }; + initialize(); + + const { showChat } = useStore(chatStore); + + const [animationScope, animate] = useAnimate(); + + const { + messages, + isLoading, + input, + handleInputChange, + setInput, + stop, + append, + } = useChat({ + api: "/api/chat", + onError: (error) => { + logger.error("Request failed\n\n", error); + toast.error("There was an error processing your request"); + }, + onFinish: () => { + logger.debug("Finished streaming"); + }, + initialMessages, + }); + + const { enhancingPrompt, promptEnhanced, enhancePrompt, resetEnhancer } = + usePromptEnhancer(); + const { parsedMessages, parseMessages } = useMessageParser(); + + const TEXTAREA_MAX_HEIGHT = chatStarted ? 
400 : 200; + + useEffect(() => { + chatStore.setKey("started", initialMessages.length > 0); + }, []); + + useEffect(() => { + parseMessages(messages, isLoading); + + if (messages.length > initialMessages.length) { + storeMessageHistory(messages).catch((error) => + toast.error(error.message), + ); + } + }, [messages, isLoading, parseMessages]); + + const scrollTextArea = () => { + const textarea = textareaRef.current; + + if (textarea) { + textarea.scrollTop = textarea.scrollHeight; + } + }; + + const abort = () => { + stop(); + chatStore.setKey("aborted", true); + workbenchStore.abortAllActions(); + }; + + useEffect(() => { + const textarea = textareaRef.current; + + if (textarea) { + textarea.style.height = "auto"; + + const scrollHeight = textarea.scrollHeight; + + textarea.style.height = `${Math.min(scrollHeight, TEXTAREA_MAX_HEIGHT)}px`; + textarea.style.overflowY = + scrollHeight > TEXTAREA_MAX_HEIGHT ? "auto" : "hidden"; + } + }, [input, textareaRef]); + + const runAnimation = async () => { + if (chatStarted) { + return; + } + + await Promise.all([ + animate( + "#examples", + { opacity: 0, display: "none" }, + { duration: 0.1 }, + ), + animate( + "#intro", + { opacity: 0, flex: 1 }, + { duration: 0.2, ease: cubicEasingFn }, + ), + ]); + + chatStore.setKey("started", true); + + setChatStarted(true); + }; + + const sendMessage = async ( + _event: React.UIEvent, + messageInput?: string, + ) => { + const _input = messageInput || input; + + if (_input.length === 0 || isLoading) { + return; + } + + /** + * @note (delm) Usually saving files shouldn't take long but it may take longer if there + * many unsaved files. In that case we need to block user input and show an indicator + * of some kind so the user is aware that something is happening. But I consider the + * happy case to be no unsaved files and I would expect users to save their changes + * before they send another message. + */ + await workbenchStore.saveAllFiles(); + + const fileModifications = workbenchStore.getFileModifcations(); + + chatStore.setKey("aborted", false); + + runAnimation(); + const message = { role: "user", content: "" }; + const body = { + model, + provider, + api_key, + }; + if (fileModifications !== undefined) { + const diff = fileModificationsToHTML(fileModifications); + + /** + * If we have file modifications we append a new user message manually since we have to prefix + * the user input with the file modifications and we don't want the new user input to appear + * in the prompt. Using `append` is almost the same as `handleSubmit` except that we have to + * manually reset the input and we'd have to manually pass in file attachments. However, those + * aren't relevant here. + */ + message.content = `${diff}\n\n${_input}`; + + /** + * After sending a new message we reset all modifications since the model + * should now be aware of all the changes. 
+ */ + workbenchStore.resetAllFileModifications(); + } else { + message.content = _input; + } + append(message, { + body, + }); + setInput(""); + + resetEnhancer(); + + textareaRef.current?.blur(); + }; + + const [messageRef, scrollRef] = useSnapScroll(); + + return ( + { + if (message.role === "user") { + return message; + } + + return { + ...message, + content: parsedMessages[i] || "", + }; + })} + enhancePrompt={() => { + enhancePrompt( + { + input, + model, + provider, + api_key, + }, + (input) => { + setInput(input); + scrollTextArea(); + }, + ); + }} + /> + ); + }, +); diff --git a/app/entry.server.tsx b/app/entry.server.tsx index be2b42bf0..0217634a1 100644 --- a/app/entry.server.tsx +++ b/app/entry.server.tsx @@ -5,7 +5,6 @@ import { renderToReadableStream } from 'react-dom/server'; import { renderHeadToString } from 'remix-island'; import { Head } from './root'; import { themeStore } from '~/lib/stores/theme'; -import { initializeModelList } from '~/utils/constants'; export default async function handleRequest( request: Request, @@ -14,7 +13,7 @@ export default async function handleRequest( remixContext: EntryContext, _loadContext: AppLoadContext, ) { - await initializeModelList(); + const readable = await renderToReadableStream(, { signal: request.signal, diff --git a/app/lib/.server/llm/model.ts b/app/lib/.server/llm/model.ts index 390d57aeb..8338ab387 100644 --- a/app/lib/.server/llm/model.ts +++ b/app/lib/.server/llm/model.ts @@ -6,7 +6,6 @@ import { createOpenAI } from '@ai-sdk/openai'; import { createGoogleGenerativeAI } from '@ai-sdk/google'; import { ollama } from 'ollama-ai-provider'; import { createOpenRouter } from "@openrouter/ai-sdk-provider"; -import { mistral } from '@ai-sdk/mistral'; import { createMistral } from '@ai-sdk/mistral'; export function getAnthropicModel(apiKey: string, model: string) { @@ -80,27 +79,27 @@ export function getOpenRouterModel(apiKey: string, model: string) { return openRouter.chat(model); } -export function getModel(provider: string, model: string, env: Env) { - const apiKey = getAPIKey(env, provider); +export function getModel(provider: string, model: string,apiKey:string, env: Env) { + const _apiKey = apiKey || getAPIKey(env, provider); const baseURL = getBaseURL(env, provider); switch (provider) { case 'Anthropic': - return getAnthropicModel(apiKey, model); + return getAnthropicModel(_apiKey, model); case 'OpenAI': - return getOpenAIModel(apiKey, model); + return getOpenAIModel(_apiKey, model); case 'Groq': - return getGroqModel(apiKey, model); + return getGroqModel(_apiKey, model); case 'OpenRouter': - return getOpenRouterModel(apiKey, model); + return getOpenRouterModel(_apiKey, model); case 'Google': - return getGoogleModel(apiKey, model) + return getGoogleModel(_apiKey, model) case 'OpenAILike': - return getOpenAILikeModel(baseURL,apiKey, model); + return getOpenAILikeModel(baseURL,_apiKey, model); case 'Deepseek': - return getDeepseekModel(apiKey, model) + return getDeepseekModel(_apiKey, model) case 'Mistral': - return getMistralModel(apiKey, model); + return getMistralModel(_apiKey, model); default: return getOllamaModel(baseURL, model); } diff --git a/app/lib/.server/llm/stream-text.ts b/app/lib/.server/llm/stream-text.ts index de3d5bfa8..f1e057f41 100644 --- a/app/lib/.server/llm/stream-text.ts +++ b/app/lib/.server/llm/stream-text.ts @@ -1,66 +1,64 @@ // @ts-nocheck // Preventing TS checks with files presented in the video for a better presentation. 
-import { streamText as _streamText, convertToCoreMessages } from 'ai'; -import { getModel } from '~/lib/.server/llm/model'; -import { MAX_TOKENS } from './constants'; -import { getSystemPrompt } from './prompts'; -import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants'; +import { streamText as _streamText, convertToCoreMessages } from "ai"; +import { getModel } from "~/lib/.server/llm/model"; +import { MAX_TOKENS } from "./constants"; +import { getSystemPrompt } from "./prompts"; +import { DEFAULT_MODEL, DEFAULT_PROVIDER, hasModel } from "~/utils/constants"; +import type { ChatRequest } from "~/routes/api.chat"; interface ToolResult { - toolCallId: string; - toolName: Name; - args: Args; - result: Result; + toolCallId: string; + toolName: Name; + args: Args; + result: Result; } interface Message { - role: 'user' | 'assistant'; - content: string; - toolInvocations?: ToolResult[]; - model?: string; + role: "user" | "assistant"; + content: string; + toolInvocations?: ToolResult[]; + model?: string; } export type Messages = Message[]; -export type StreamingOptions = Omit[0], 'model'>; +export type StreamingOptions = Omit[0], "model">; -function extractModelFromMessage(message: Message): { model: string; content: string } { - const modelRegex = /^\[Model: (.*?)\]\n\n/; - const match = message.content.match(modelRegex); +// function extractModelFromMessage(message: Message): { model: string; content: string } { +// const modelRegex = /^\[Model: (.*?)Provider: (.*?)\]\n\n/; +// const match = message.content.match(modelRegex); +// +// if (!match) { +// return { model: DEFAULT_MODEL, content: message.content,provider: DEFAULT_PROVIDER }; +// } +// const [_,model,provider] = match; +// const content = message.content.replace(modelRegex, ''); +// return { model, content ,provider}; +// // Default model if not specified +// +// } - if (match) { - const model = match[1]; - const content = message.content.replace(modelRegex, ''); - return { model, content }; - } +export function streamText( + chatRequest: ChatRequest, + env: Env, + options?: StreamingOptions, +) { + const { messages, model, api_key, provider } = chatRequest; - // Default model if not specified - return { model: DEFAULT_MODEL, content: message.content }; -} - -export function streamText(messages: Messages, env: Env, options?: StreamingOptions) { - let currentModel = DEFAULT_MODEL; - const processedMessages = messages.map((message) => { - if (message.role === 'user') { - const { model, content } = extractModelFromMessage(message); - if (model && MODEL_LIST.find((m) => m.name === model)) { - currentModel = model; // Update the current model - } - return { ...message, content }; - } - return message; - }); - - const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER; + const _hasModel = hasModel(model, provider); + let currentModel = _hasModel ? model : DEFAULT_MODEL; + let currentProvider = _hasModel ? 
provider : DEFAULT_PROVIDER; - return _streamText({ - model: getModel(provider, currentModel, env), - system: getSystemPrompt(), - maxTokens: MAX_TOKENS, - // headers: { - // 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15', - // }, - messages: convertToCoreMessages(processedMessages), - ...options, - }); + const coreMessages = convertToCoreMessages(messages); + return _streamText({ + model: getModel(currentProvider, currentModel, api_key, env), + system: getSystemPrompt(), + maxTokens: MAX_TOKENS, + // headers: { + // 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15', + // }, + messages: coreMessages, + ...options, + }); } diff --git a/app/lib/hooks/usePromptEnhancer.ts b/app/lib/hooks/usePromptEnhancer.ts index f376cc0cd..780989a84 100644 --- a/app/lib/hooks/usePromptEnhancer.ts +++ b/app/lib/hooks/usePromptEnhancer.ts @@ -1,71 +1,83 @@ -import { useState } from 'react'; -import { createScopedLogger } from '~/utils/logger'; +import { useState } from "react"; +import { createScopedLogger } from "~/utils/logger"; -const logger = createScopedLogger('usePromptEnhancement'); +const logger = createScopedLogger("usePromptEnhancement"); export function usePromptEnhancer() { - const [enhancingPrompt, setEnhancingPrompt] = useState(false); - const [promptEnhanced, setPromptEnhanced] = useState(false); - - const resetEnhancer = () => { - setEnhancingPrompt(false); - setPromptEnhanced(false); - }; - - const enhancePrompt = async (input: string, setInput: (value: string) => void) => { - setEnhancingPrompt(true); - setPromptEnhanced(false); - - const response = await fetch('/api/enhancer', { - method: 'POST', - body: JSON.stringify({ - message: input, - }), - }); - - const reader = response.body?.getReader(); - - const originalInput = input; - - if (reader) { - const decoder = new TextDecoder(); - - let _input = ''; - let _error; - - try { - setInput(''); - - while (true) { - const { value, done } = await reader.read(); - - if (done) { - break; - } - - _input += decoder.decode(value); - - logger.trace('Set input', _input); - - setInput(_input); - } - } catch (error) { - _error = error; - setInput(originalInput); - } finally { - if (_error) { - logger.error(_error); - } - - setEnhancingPrompt(false); - setPromptEnhanced(true); - - setTimeout(() => { - setInput(_input); - }); - } - } - }; - - return { enhancingPrompt, promptEnhanced, enhancePrompt, resetEnhancer }; + const [enhancingPrompt, setEnhancingPrompt] = useState(false); + const [promptEnhanced, setPromptEnhanced] = useState(false); + + const resetEnhancer = () => { + setEnhancingPrompt(false); + setPromptEnhanced(false); + }; + type EnhancePrompt = { + input: string; + model: string; + provider: string; + api_key: string; + }; + + const enhancePrompt = async ( + { input, model, provider, api_key }: EnhancePrompt, + setInput: (value: string) => void, + ) => { + setEnhancingPrompt(true); + setPromptEnhanced(false); + + const response = await fetch("/api/enhancer", { + method: "POST", + body: JSON.stringify({ + message: input, + model, + provider, + api_key, + }), + }); + + const reader = response.body?.getReader(); + + const originalInput = input; + + if (reader) { + const decoder = new TextDecoder(); + + let _input = ""; + let _error; + + try { + setInput(""); + + while (true) { + const { value, done } = await reader.read(); + + if (done) { + break; + } + + _input += decoder.decode(value); + + logger.trace("Set input", _input); + + setInput(_input); + } + } catch (error) { + _error = error; + setInput(originalInput); + } finally { + if 
(_error) { + logger.error(_error); + } + + setEnhancingPrompt(false); + setPromptEnhanced(true); + + setTimeout(() => { + setInput(_input); + }); + } + } + }; + + return { enhancingPrompt, promptEnhanced, enhancePrompt, resetEnhancer }; } diff --git a/app/routes/_index.tsx b/app/routes/_index.tsx index 86d73409c..7f197908a 100644 --- a/app/routes/_index.tsx +++ b/app/routes/_index.tsx @@ -3,6 +3,8 @@ import { ClientOnly } from 'remix-utils/client-only'; import { BaseChat } from '~/components/chat/BaseChat'; import { Chat } from '~/components/chat/Chat.client'; import { Header } from '~/components/header/Header'; +import { useState } from 'react'; +import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST } from '~/utils/constants'; export const meta: MetaFunction = () => { return [{ title: 'Bolt' }, { name: 'description', content: 'Talk with Bolt, an AI assistant from StackBlitz' }]; @@ -11,10 +13,14 @@ export const meta: MetaFunction = () => { export const loader = () => json({}); export default function Index() { + const [model, setModel] = useState(DEFAULT_MODEL); + const [provider, setProvider] = useState(DEFAULT_PROVIDER); + const [modelList, setModelList] = useState(MODEL_LIST); + const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]); return (
- }>{() => } + }>{() => }
); } diff --git a/app/routes/api.chat.ts b/app/routes/api.chat.ts index c455c193d..69db9a912 100644 --- a/app/routes/api.chat.ts +++ b/app/routes/api.chat.ts @@ -1,61 +1,79 @@ // @ts-nocheck // Preventing TS checks with files presented in the video for a better presentation. -import { type ActionFunctionArgs } from '@remix-run/cloudflare'; -import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants'; -import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts'; -import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text'; -import SwitchableStream from '~/lib/.server/llm/switchable-stream'; +import { type ActionFunctionArgs } from "@remix-run/cloudflare"; +import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from "~/lib/.server/llm/constants"; +import { CONTINUE_PROMPT } from "~/lib/.server/llm/prompts"; +import { + streamText, + type Messages, + type StreamingOptions, +} from "~/lib/.server/llm/stream-text"; +import SwitchableStream from "~/lib/.server/llm/switchable-stream"; export async function action(args: ActionFunctionArgs) { - return chatAction(args); + return chatAction(args); } +export type ChatRequest = { + messages: Messages; + model: string; + provider: string; + api_key: string; +}; async function chatAction({ context, request }: ActionFunctionArgs) { - const { messages } = await request.json<{ messages: Messages }>(); + const chatRequest = await request.json(); + const stream = new SwitchableStream(); + try { + const options: StreamingOptions = { + toolChoice: "none", + onFinish: async ({ text: content, finishReason }) => { + if (finishReason !== "length") { + return stream.close(); + } - const stream = new SwitchableStream(); + if (stream.switches >= MAX_RESPONSE_SEGMENTS) { + throw Error("Cannot continue message: Maximum segments reached"); + } - try { - const options: StreamingOptions = { - toolChoice: 'none', - onFinish: async ({ text: content, finishReason }) => { - if (finishReason !== 'length') { - return stream.close(); - } + const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches; - if (stream.switches >= MAX_RESPONSE_SEGMENTS) { - throw Error('Cannot continue message: Maximum segments reached'); - } + console.log( + `Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`, + ); - const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches; + messages.push({ role: "assistant", content }); + messages.push({ role: "user", content: CONTINUE_PROMPT }); - console.log(`Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`); + const result = await streamText( + chatRequest, + context.cloudflare.env, + options, + ); - messages.push({ role: 'assistant', content }); - messages.push({ role: 'user', content: CONTINUE_PROMPT }); + return stream.switchSource(result.toAIStream()); + }, + }; - const result = await streamText(messages, context.cloudflare.env, options); + const result = await streamText( + chatRequest, + context.cloudflare.env, + options, + ); - return stream.switchSource(result.toAIStream()); - }, - }; + stream.switchSource(result.toAIStream()); - const result = await streamText(messages, context.cloudflare.env, options); + return new Response(stream.readable, { + status: 200, + headers: { + contentType: "text/plain; charset=utf-8", + }, + }); + } catch (error) { + console.log(error); - stream.switchSource(result.toAIStream()); - - return new Response(stream.readable, { - status: 200, - headers: { - contentType: 'text/plain; 
charset=utf-8', - }, - }); - } catch (error) { - console.log(error); - - throw new Response(null, { - status: 500, - statusText: 'Internal Server Error', - }); - } + throw new Response(null, { + status: 500, + statusText: "Internal Server Error", + }); + } } diff --git a/app/routes/api.enhancer.ts b/app/routes/api.enhancer.ts index 5c8175ca3..d555aac98 100644 --- a/app/routes/api.enhancer.ts +++ b/app/routes/api.enhancer.ts @@ -1,24 +1,32 @@ -import { type ActionFunctionArgs } from '@remix-run/cloudflare'; -import { StreamingTextResponse, parseStreamPart } from 'ai'; -import { streamText } from '~/lib/.server/llm/stream-text'; -import { stripIndents } from '~/utils/stripIndent'; +import { type ActionFunctionArgs } from "@remix-run/cloudflare"; +import { StreamingTextResponse, parseStreamPart } from "ai"; +import { type Messages, streamText } from "~/lib/.server/llm/stream-text"; +import { stripIndents } from "~/utils/stripIndent"; +import type { ChatRequest } from "~/routes/api.chat"; const encoder = new TextEncoder(); const decoder = new TextDecoder(); export async function action(args: ActionFunctionArgs) { - return enhancerAction(args); + return enhancerAction(args); } - +type EnhancePrompt = { + message: string; + model: string; + provider: string; + api_key: string; +}; async function enhancerAction({ context, request }: ActionFunctionArgs) { - const { message } = await request.json<{ message: string }>(); - - try { - const result = await streamText( - [ - { - role: 'user', - content: stripIndents` + const { message, model, provider, api_key } = + await request.json(); + const input: ChatRequest = { + model, + provider, + api_key, + messages: [ + { + role: "user", + content: stripIndents` I want you to improve the user prompt that is wrapped in \`\` tags. IMPORTANT: Only respond with the improved prompt and nothing else! 
@@ -27,34 +35,35 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) { ${message} `, - }, - ], - context.cloudflare.env, - ); - - const transformStream = new TransformStream({ - transform(chunk, controller) { - const processedChunk = decoder - .decode(chunk) - .split('\n') - .filter((line) => line !== '') - .map(parseStreamPart) - .map((part) => part.value) - .join(''); - - controller.enqueue(encoder.encode(processedChunk)); - }, - }); - - const transformedStream = result.toAIStream().pipeThrough(transformStream); - - return new StreamingTextResponse(transformedStream); - } catch (error) { - console.log(error); - - throw new Response(null, { - status: 500, - statusText: 'Internal Server Error', - }); - } + }, + ], + }; + try { + const result = await streamText(input, context.cloudflare.env); + + const transformStream = new TransformStream({ + transform(chunk, controller) { + const processedChunk = decoder + .decode(chunk) + .split("\n") + .filter((line) => line !== "") + .map(parseStreamPart) + .map((part) => part.value) + .join(""); + + controller.enqueue(encoder.encode(processedChunk)); + }, + }); + + const transformedStream = result.toAIStream().pipeThrough(transformStream); + + return new StreamingTextResponse(transformedStream); + } catch (error) { + console.log(error); + + throw new Response(null, { + status: 500, + statusText: "Internal Server Error", + }); + } } diff --git a/app/routes/api.models.ts b/app/routes/api.models.ts index ace4ef009..4fdff3812 100644 --- a/app/routes/api.models.ts +++ b/app/routes/api.models.ts @@ -1,6 +1,8 @@ -import { json } from '@remix-run/cloudflare'; -import { MODEL_LIST } from '~/utils/constants'; +import { json, type LoaderFunctionArgs } from '@remix-run/cloudflare'; +import { initializeModelList } from '~/utils/tools'; -export async function loader() { - return json(MODEL_LIST); +export async function loader({context}: LoaderFunctionArgs) { + const { env } = context.cloudflare; + const modelList = await initializeModelList(env); + return json(modelList); } diff --git a/app/utils/constants.ts b/app/utils/constants.ts index b48cb3442..d3a49b01d 100644 --- a/app/utils/constants.ts +++ b/app/utils/constants.ts @@ -1,5 +1,4 @@ -import type { ModelInfo, OllamaApiResponse, OllamaModel } from './types'; - +import type { ModelInfo } from './types'; export const WORK_DIR_NAME = 'project'; export const WORK_DIR = `/home/${WORK_DIR_NAME}`; export const MODIFICATIONS_TAG_NAME = 'bolt_file_modifications'; @@ -46,50 +45,52 @@ const staticModels: ModelInfo[] = [ ]; export let MODEL_LIST: ModelInfo[] = [...staticModels]; +export const PROVIDER_LIST: string[] = ['Ollama', 'OpenAILike'] -async function getOllamaModels(): Promise { - try { - const base_url = import.meta.env.OLLAMA_API_BASE_URL || "http://localhost:11434"; - const response = await fetch(`${base_url}/api/tags`); - const data = await response.json() as OllamaApiResponse; - - return data.models.map((model: OllamaModel) => ({ - name: model.name, - label: `${model.name} (${model.details.parameter_size})`, - provider: 'Ollama', - })); - } catch (e) { - return []; +export function hasModel(modelName:string,provider:string): boolean { + for (const model of MODEL_LIST) { + if ( model.provider === provider && model.name === modelName) { + return true; + } } + return false } +export const IS_SERVER = typeof window === 'undefined'; -async function getOpenAILikeModels(): Promise { - try { - const base_url =import.meta.env.OPENAI_LIKE_API_BASE_URL || ""; - if (!base_url) { - return []; - } - 
const api_key = import.meta.env.OPENAI_LIKE_API_KEY ?? ""; - const response = await fetch(`${base_url}/models`, { - headers: { - Authorization: `Bearer ${api_key}`, - } - }); - const res = await response.json() as any; - return res.data.map((model: any) => ({ - name: model.id, - label: model.id, - provider: 'OpenAILike', - })); - }catch (e) { - return [] - } +export function setModelList(models: ModelInfo[]): void { + MODEL_LIST = models; +} +export function getStaticModels(): ModelInfo[] { + return [...staticModels]; } -async function initializeModelList(): Promise { - const ollamaModels = await getOllamaModels(); - const openAiLikeModels = await getOpenAILikeModels(); - MODEL_LIST = [...ollamaModels,...openAiLikeModels, ...staticModels]; +export let isInitialized = false; + +export type ModelListInfo= { + modelList: ModelInfo[], + providerList: string[] +} + +export async function initializeModelList(): Promise { + if (isInitialized) { + return getModelListInfo(); + } + + if (IS_SERVER) { + isInitialized = true; + return getModelListInfo(); + } + + isInitialized = true; + const response = await fetch('/api/models'); + MODEL_LIST = (await response.json()) as ModelInfo[]; + + return getModelListInfo(); +} + +export function getModelListInfo(): ModelListInfo { + return { + modelList: MODEL_LIST, + providerList: [...new Set([...MODEL_LIST.map(m => m.provider), ...PROVIDER_LIST])] + }; } -initializeModelList().then(); -export { getOllamaModels, getOpenAILikeModels, initializeModelList }; diff --git a/app/utils/tools.ts b/app/utils/tools.ts new file mode 100644 index 000000000..74cbdadb4 --- /dev/null +++ b/app/utils/tools.ts @@ -0,0 +1,74 @@ +import type { ModelInfo, OllamaApiResponse, OllamaModel } from '~/utils/types'; +import { getStaticModels,setModelList } from '~/utils/constants'; +import { getAPIKey, getBaseURL } from '~/lib/.server/llm/api-key'; + + +export let MODEL_LIST: ModelInfo[] = [...getStaticModels()]; + + + +let isInitialized = false; +async function getOllamaModels(env: Env): Promise { + try { + const base_url = getBaseURL(env,"Ollama") ; + const response = await fetch(`${base_url}/api/tags`); + const data = await response.json() as OllamaApiResponse; + return data.models.map((model: OllamaModel) => ({ + name: model.name, + label: `${model.name} (${model.details.parameter_size})`, + provider: 'Ollama', + })); + } catch (e) { + return [{ + name: "Empty", + label: "Empty", + provider: "Ollama" + }]; + } +} + +async function getOpenAILikeModels(env: Env): Promise { + try { + const base_url = getBaseURL(env,"OpenAILike") ; + if (!base_url) { + return [{ + name: "Empty", + label: "Empty", + provider: "OpenAILike" + }]; + } + const api_key = getAPIKey(env,"OpenAILike") ?? 
""; + const response = await fetch(`${base_url}/models`, { + headers: { + Authorization: `Bearer ${api_key}`, + } + }); + const res = await response.json() as any; + return res.data.map((model: any) => ({ + name: model.id, + label: model.id, + provider: 'OpenAILike', + })); + }catch (e) { + return [{ + name: "Empty", + label: "Empty", + provider: "OpenAILike" + }]; + } +} + + +async function initializeModelList(env: Env): Promise { + if (isInitialized) { + return MODEL_LIST; + } + isInitialized = true; + const ollamaModels = await getOllamaModels(env); + const openAiLikeModels = await getOpenAILikeModels(env); + MODEL_LIST = [...getStaticModels(), ...ollamaModels, ...openAiLikeModels]; + setModelList(MODEL_LIST); + return MODEL_LIST; +} + +export { getOllamaModels, getOpenAILikeModels, initializeModelList }; diff --git a/package.json b/package.json index edb2b8dad..1f9f0e845 100644 --- a/package.json +++ b/package.json @@ -6,13 +6,14 @@ "sideEffects": false, "type": "module", "scripts": { - "deploy": "npm run build && wrangler pages deploy", + "deploy": "pnpm run build && wrangler pages deploy", + "dev:deploy": "pnpm run build && wrangler pages dev", "build": "remix vite:build", "dev": "remix vite:dev", "test": "vitest --run", "test:watch": "vitest", "lint": "eslint --cache --cache-location ./node_modules/.cache/eslint .", - "lint:fix": "npm run lint -- --fix", + "lint:fix": "pnpm run lint -- --fix", "start": "bindings=$(./bindings.sh) && wrangler pages dev ./build/client $bindings --ip 0.0.0.0 --port 3000", "typecheck": "tsc", "typegen": "wrangler types",