From 1bec7438dc4d9720a3d113417ba2466bb23fb83a Mon Sep 17 00:00:00 2001 From: Gaoyao Massimo Hu Date: Sat, 26 Oct 2024 15:26:20 +0100 Subject: [PATCH] Update constants.ts and fix model name bug --- app/components/chat/BaseChat.tsx | 52 ++++++++++---------- app/components/chat/Chat.client.tsx | 23 +++++++-- app/entry.server.tsx | 3 +- app/lib/.server/llm/model.ts | 1 - app/lib/.server/llm/stream-text.ts | 37 ++++++++------- app/routes/_index.tsx | 8 +++- app/routes/api.models.ts | 10 ++-- app/utils/constants.ts | 70 +++++++++++---------------- app/utils/tools.ts | 74 +++++++++++++++++++++++++++++ package.json | 5 +- 10 files changed, 187 insertions(+), 96 deletions(-) create mode 100644 app/utils/tools.ts diff --git a/app/components/chat/BaseChat.tsx b/app/components/chat/BaseChat.tsx index c1175f700..9d8fa87c4 100644 --- a/app/components/chat/BaseChat.tsx +++ b/app/components/chat/BaseChat.tsx @@ -7,12 +7,11 @@ import { Menu } from '~/components/sidebar/Menu.client'; import { IconButton } from '~/components/ui/IconButton'; import { Workbench } from '~/components/workbench/Workbench.client'; import { classNames } from '~/utils/classNames'; -import { MODEL_LIST, DEFAULT_PROVIDER } from '~/utils/constants'; import { Messages } from './Messages.client'; import { SendButton } from './SendButton.client'; -import { useState } from 'react'; import styles from './BaseChat.module.scss'; +import type { ModelInfo } from '~/utils/types'; const EXAMPLE_PROMPTS = [ { text: 'Build a todo app in React using Tailwind' }, @@ -22,36 +21,33 @@ const EXAMPLE_PROMPTS = [ { text: 'How do I center a div?' }, ]; -const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))] -const ModelSelector = ({ model, setModel, modelList, providerList }) => { - const [provider, setProvider] = useState(DEFAULT_PROVIDER); + +function ModelSelector({ model, setModel ,provider,setProvider,modelList,providerList}) { + + + + return (
); -}; +} const TEXTAREA_MIN_HEIGHT = 76; @@ -77,8 +73,12 @@ interface BaseChatProps { enhancingPrompt?: boolean; promptEnhanced?: boolean; input?: string; - model: string; - setModel: (model: string) => void; + model?: string; + setModel?: (model: string) => void; + provider?: string; + setProvider?: (provider: string) => void; + modelList?: ModelInfo[]; + providerList?: string[]; handleStop?: () => void; sendMessage?: (event: React.UIEvent, messageInput?: string) => void; handleInputChange?: (event: React.ChangeEvent) => void; @@ -100,6 +100,10 @@ export const BaseChat = React.forwardRef( input = '', model, setModel, + provider, + setProvider, + modelList, + providerList, sendMessage, handleInputChange, enhancePrompt, @@ -108,7 +112,6 @@ export const BaseChat = React.forwardRef( ref, ) => { const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200; - return (
(
( } event.preventDefault(); - sendMessage?.(event); } }} diff --git a/app/components/chat/Chat.client.tsx b/app/components/chat/Chat.client.tsx index 458bd8364..ed4897b4c 100644 --- a/app/components/chat/Chat.client.tsx +++ b/app/components/chat/Chat.client.tsx @@ -11,7 +11,7 @@ import { useChatHistory } from '~/lib/persistence'; import { chatStore } from '~/lib/stores/chat'; import { workbenchStore } from '~/lib/stores/workbench'; import { fileModificationsToHTML } from '~/utils/diff'; -import { DEFAULT_MODEL } from '~/utils/constants'; +import { DEFAULT_MODEL, DEFAULT_PROVIDER, initializeModelList, isInitialized, MODEL_LIST } from '~/utils/constants'; import { cubicEasingFn } from '~/utils/easings'; import { createScopedLogger, renderLogger } from '~/utils/logger'; import { BaseChat } from './BaseChat'; @@ -74,6 +74,19 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp const [chatStarted, setChatStarted] = useState(initialMessages.length > 0); const [model, setModel] = useState(DEFAULT_MODEL); + const [provider, setProvider] = useState(DEFAULT_PROVIDER); + const [modelList, setModelList] = useState(MODEL_LIST); + const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]); + const initialize = async () => { + if (!isInitialized) { + const models= await initializeModelList(); + const modelList = models; + const providerList = [...new Set([...models.map((m) => m.provider),"Ollama","OpenAILike"])]; + setModelList(modelList); + setProviderList(providerList); + } + }; + initialize(); const { showChat } = useStore(chatStore); @@ -182,7 +195,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp * manually reset the input and we'd have to manually pass in file attachments. However, those * aren't relevant here. */ - append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` }); + append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${diff}\n\n${_input}` }); /** * After sending a new message we reset all modifications since the model @@ -190,7 +203,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp */ workbenchStore.resetAllFileModifications(); } else { - append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` }); + append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${_input}` }); } setInput(''); @@ -215,6 +228,10 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp sendMessage={sendMessage} model={model} setModel={setModel} + provider={provider} + setProvider={setProvider} + modelList={modelList} + providerList={providerList} messageRef={messageRef} scrollRef={scrollRef} handleInputChange={handleInputChange} diff --git a/app/entry.server.tsx b/app/entry.server.tsx index be2b42bf0..0217634a1 100644 --- a/app/entry.server.tsx +++ b/app/entry.server.tsx @@ -5,7 +5,6 @@ import { renderToReadableStream } from 'react-dom/server'; import { renderHeadToString } from 'remix-island'; import { Head } from './root'; import { themeStore } from '~/lib/stores/theme'; -import { initializeModelList } from '~/utils/constants'; export default async function handleRequest( request: Request, @@ -14,7 +13,7 @@ export default async function handleRequest( remixContext: EntryContext, _loadContext: AppLoadContext, ) { - await initializeModelList(); + const readable = await renderToReadableStream(, { signal: request.signal, diff --git a/app/lib/.server/llm/model.ts b/app/lib/.server/llm/model.ts index 390d57aeb..426de18d0 100644 --- a/app/lib/.server/llm/model.ts +++ b/app/lib/.server/llm/model.ts @@ -6,7 +6,6 @@ import { createOpenAI } from '@ai-sdk/openai'; import { createGoogleGenerativeAI } from '@ai-sdk/google'; import { ollama } from 'ollama-ai-provider'; import { createOpenRouter } from "@openrouter/ai-sdk-provider"; -import { mistral } from '@ai-sdk/mistral'; import { createMistral } from '@ai-sdk/mistral'; export function getAnthropicModel(apiKey: string, model: string) { diff --git a/app/lib/.server/llm/stream-text.ts b/app/lib/.server/llm/stream-text.ts index de3d5bfa8..e3f233bcf 100644 --- a/app/lib/.server/llm/stream-text.ts +++ b/app/lib/.server/llm/stream-text.ts @@ -4,7 +4,7 @@ import { streamText as _streamText, convertToCoreMessages } from 'ai'; import { getModel } from '~/lib/.server/llm/model'; import { MAX_TOKENS } from './constants'; import { getSystemPrompt } from './prompts'; -import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants'; +import { DEFAULT_MODEL, DEFAULT_PROVIDER, hasModel } from '~/utils/constants'; interface ToolResult { toolCallId: string; @@ -25,42 +25,47 @@ export type Messages = Message[]; export type StreamingOptions = Omit[0], 'model'>; function extractModelFromMessage(message: Message): { model: string; content: string } { - const modelRegex = /^\[Model: (.*?)\]\n\n/; + const modelRegex = /^\[Model: (.*?)Provider: (.*?)\]\n\n/; const match = message.content.match(modelRegex); - if (match) { - const model = match[1]; - const content = message.content.replace(modelRegex, ''); - return { model, content }; + if (!match) { + return { model: DEFAULT_MODEL, content: message.content,provider: DEFAULT_PROVIDER }; } - + const [_,model,provider] = match; + const content = message.content.replace(modelRegex, ''); + return { model, content ,provider}; // Default model if not specified - return { model: DEFAULT_MODEL, content: message.content }; + } export function streamText(messages: Messages, env: Env, options?: StreamingOptions) { let currentModel = DEFAULT_MODEL; + let currentProvider = DEFAULT_PROVIDER; + const lastMessage = messages.findLast((message) => message.role === 'user'); + if (lastMessage) { + const { model, provider } = extractModelFromMessage(lastMessage); + if (hasModel(model, provider)) { + currentModel = model; + currentProvider = provider; + } + } const processedMessages = messages.map((message) => { if (message.role === 'user') { - const { model, content } = extractModelFromMessage(message); - if (model && MODEL_LIST.find((m) => m.name === model)) { - currentModel = model; // Update the current model - } + const { content } = extractModelFromMessage(message); return { ...message, content }; } return message; }); - const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER; - + const coreMessages = convertToCoreMessages(processedMessages); return _streamText({ - model: getModel(provider, currentModel, env), + model: getModel(currentProvider, currentModel, env), system: getSystemPrompt(), maxTokens: MAX_TOKENS, // headers: { // 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15', // }, - messages: convertToCoreMessages(processedMessages), + messages: coreMessages, ...options, }); } diff --git a/app/routes/_index.tsx b/app/routes/_index.tsx index 86d73409c..7f197908a 100644 --- a/app/routes/_index.tsx +++ b/app/routes/_index.tsx @@ -3,6 +3,8 @@ import { ClientOnly } from 'remix-utils/client-only'; import { BaseChat } from '~/components/chat/BaseChat'; import { Chat } from '~/components/chat/Chat.client'; import { Header } from '~/components/header/Header'; +import { useState } from 'react'; +import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST } from '~/utils/constants'; export const meta: MetaFunction = () => { return [{ title: 'Bolt' }, { name: 'description', content: 'Talk with Bolt, an AI assistant from StackBlitz' }]; @@ -11,10 +13,14 @@ export const meta: MetaFunction = () => { export const loader = () => json({}); export default function Index() { + const [model, setModel] = useState(DEFAULT_MODEL); + const [provider, setProvider] = useState(DEFAULT_PROVIDER); + const [modelList, setModelList] = useState(MODEL_LIST); + const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]); return (
- }>{() => } + }>{() => }
); } diff --git a/app/routes/api.models.ts b/app/routes/api.models.ts index ace4ef009..4fdff3812 100644 --- a/app/routes/api.models.ts +++ b/app/routes/api.models.ts @@ -1,6 +1,8 @@ -import { json } from '@remix-run/cloudflare'; -import { MODEL_LIST } from '~/utils/constants'; +import { json, type LoaderFunctionArgs } from '@remix-run/cloudflare'; +import { initializeModelList } from '~/utils/tools'; -export async function loader() { - return json(MODEL_LIST); +export async function loader({context}: LoaderFunctionArgs) { + const { env } = context.cloudflare; + const modelList = await initializeModelList(env); + return json(modelList); } diff --git a/app/utils/constants.ts b/app/utils/constants.ts index b48cb3442..8ba57905b 100644 --- a/app/utils/constants.ts +++ b/app/utils/constants.ts @@ -1,5 +1,4 @@ -import type { ModelInfo, OllamaApiResponse, OllamaModel } from './types'; - +import type { ModelInfo } from './types'; export const WORK_DIR_NAME = 'project'; export const WORK_DIR = `/home/${WORK_DIR_NAME}`; export const MODIFICATIONS_TAG_NAME = 'bolt_file_modifications'; @@ -46,50 +45,35 @@ const staticModels: ModelInfo[] = [ ]; export let MODEL_LIST: ModelInfo[] = [...staticModels]; - -async function getOllamaModels(): Promise { - try { - const base_url = import.meta.env.OLLAMA_API_BASE_URL || "http://localhost:11434"; - const response = await fetch(`${base_url}/api/tags`); - const data = await response.json() as OllamaApiResponse; - - return data.models.map((model: OllamaModel) => ({ - name: model.name, - label: `${model.name} (${model.details.parameter_size})`, - provider: 'Ollama', - })); - } catch (e) { - return []; +export function hasModel(modelName:string,provider:string): boolean { + for (const model of MODEL_LIST) { + if ( model.provider === provider && model.name === modelName) { + return true; + } } + return false } +export const IS_SERVER = typeof window === 'undefined'; -async function getOpenAILikeModels(): Promise { - try { - const base_url =import.meta.env.OPENAI_LIKE_API_BASE_URL || ""; - if (!base_url) { - return []; - } - const api_key = import.meta.env.OPENAI_LIKE_API_KEY ?? ""; - const response = await fetch(`${base_url}/models`, { - headers: { - Authorization: `Bearer ${api_key}`, - } - }); - const res = await response.json() as any; - return res.data.map((model: any) => ({ - name: model.id, - label: model.id, - provider: 'OpenAILike', - })); - }catch (e) { - return [] - } +export function setModelList(models: ModelInfo[]): void { + MODEL_LIST = models; +} +export function getStaticModels(): ModelInfo[] { + return [...staticModels]; } -async function initializeModelList(): Promise { - const ollamaModels = await getOllamaModels(); - const openAiLikeModels = await getOpenAILikeModels(); - MODEL_LIST = [...ollamaModels,...openAiLikeModels, ...staticModels]; +export let isInitialized = false; + +export async function initializeModelList(): Promise { + if (isInitialized) { + return MODEL_LIST; + } + if (IS_SERVER){ + isInitialized = true; + return MODEL_LIST; + } + isInitialized = true; + const response = await fetch('/api/models'); + MODEL_LIST = (await response.json()) as ModelInfo[]; + return MODEL_LIST; } -initializeModelList().then(); -export { getOllamaModels, getOpenAILikeModels, initializeModelList }; diff --git a/app/utils/tools.ts b/app/utils/tools.ts new file mode 100644 index 000000000..74cbdadb4 --- /dev/null +++ b/app/utils/tools.ts @@ -0,0 +1,74 @@ +import type { ModelInfo, OllamaApiResponse, OllamaModel } from '~/utils/types'; +import { getStaticModels,setModelList } from '~/utils/constants'; +import { getAPIKey, getBaseURL } from '~/lib/.server/llm/api-key'; + + +export let MODEL_LIST: ModelInfo[] = [...getStaticModels()]; + + + +let isInitialized = false; +async function getOllamaModels(env: Env): Promise { + try { + const base_url = getBaseURL(env,"Ollama") ; + const response = await fetch(`${base_url}/api/tags`); + const data = await response.json() as OllamaApiResponse; + return data.models.map((model: OllamaModel) => ({ + name: model.name, + label: `${model.name} (${model.details.parameter_size})`, + provider: 'Ollama', + })); + } catch (e) { + return [{ + name: "Empty", + label: "Empty", + provider: "Ollama" + }]; + } +} + +async function getOpenAILikeModels(env: Env): Promise { + try { + const base_url = getBaseURL(env,"OpenAILike") ; + if (!base_url) { + return [{ + name: "Empty", + label: "Empty", + provider: "OpenAILike" + }]; + } + const api_key = getAPIKey(env,"OpenAILike") ?? ""; + const response = await fetch(`${base_url}/models`, { + headers: { + Authorization: `Bearer ${api_key}`, + } + }); + const res = await response.json() as any; + return res.data.map((model: any) => ({ + name: model.id, + label: model.id, + provider: 'OpenAILike', + })); + }catch (e) { + return [{ + name: "Empty", + label: "Empty", + provider: "OpenAILike" + }]; + } +} + + +async function initializeModelList(env: Env): Promise { + if (isInitialized) { + return MODEL_LIST; + } + isInitialized = true; + const ollamaModels = await getOllamaModels(env); + const openAiLikeModels = await getOpenAILikeModels(env); + MODEL_LIST = [...getStaticModels(), ...ollamaModels, ...openAiLikeModels]; + setModelList(MODEL_LIST); + return MODEL_LIST; +} + +export { getOllamaModels, getOpenAILikeModels, initializeModelList }; diff --git a/package.json b/package.json index edb2b8dad..1f9f0e845 100644 --- a/package.json +++ b/package.json @@ -6,13 +6,14 @@ "sideEffects": false, "type": "module", "scripts": { - "deploy": "npm run build && wrangler pages deploy", + "deploy": "pnpm run build && wrangler pages deploy", + "dev:deploy": "pnpm run build && wrangler pages dev", "build": "remix vite:build", "dev": "remix vite:dev", "test": "vitest --run", "test:watch": "vitest", "lint": "eslint --cache --cache-location ./node_modules/.cache/eslint .", - "lint:fix": "npm run lint -- --fix", + "lint:fix": "pnpm run lint -- --fix", "start": "bindings=$(./bindings.sh) && wrangler pages dev ./build/client $bindings --ip 0.0.0.0 --port 3000", "typecheck": "tsc", "typegen": "wrangler types",