Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update constants.ts and fix model name bug #89

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ import { Menu } from '~/components/sidebar/Menu.client';
import { IconButton } from '~/components/ui/IconButton';
import { Workbench } from '~/components/workbench/Workbench.client';
import { classNames } from '~/utils/classNames';
import { MODEL_LIST, DEFAULT_PROVIDER } from '~/utils/constants';
import { Messages } from './Messages.client';
import { SendButton } from './SendButton.client';
import { useState } from 'react';

import styles from './BaseChat.module.scss';
import type { ModelInfo } from '~/utils/types';

const EXAMPLE_PROMPTS = [
{ text: 'Build a todo app in React using Tailwind' },
Expand All @@ -22,36 +21,33 @@ const EXAMPLE_PROMPTS = [
{ text: 'How do I center a div?' },
];

const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))]

const ModelSelector = ({ model, setModel, modelList, providerList }) => {
const [provider, setProvider] = useState(DEFAULT_PROVIDER);

function ModelSelector({ model, setModel ,provider,setProvider,modelList,providerList}) {




return (
<div className="mb-2">
<select
value={provider}
onChange={(e) => {
setProvider(e.target.value);
const firstModel = [...modelList].find(m => m.provider == e.target.value);
const firstModel = modelList.find((m) => m.provider === e.target.value);
setModel(firstModel ? firstModel.name : '');
}}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{providerList.map((provider) => (
<option key={provider} value={provider}>
{provider}
</option>
))}
<option key="Ollama" value="Ollama">
Ollama
</option>
<option key="OpenAILike" value="OpenAILike">
OpenAILike
</option>
{providerList.map(providerName=>( <option key={providerName} value={providerName}>
{providerName}
</option>))}
</select>
<select
value={model}
onChange={(e) => setModel(e.target.value)}
onChange={(e) => {
setModel(e.target.value)
}}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{[...modelList].filter(e => e.provider == provider && e.name).map((modelOption) => (
Expand All @@ -62,7 +58,7 @@ const ModelSelector = ({ model, setModel, modelList, providerList }) => {
</select>
</div>
);
};
}

const TEXTAREA_MIN_HEIGHT = 76;

Expand All @@ -77,8 +73,12 @@ interface BaseChatProps {
enhancingPrompt?: boolean;
promptEnhanced?: boolean;
input?: string;
model: string;
setModel: (model: string) => void;
model?: string;
setModel?: (model: string) => void;
provider?: string;
setProvider?: (provider: string) => void;
modelList?: ModelInfo[];
providerList?: string[];
handleStop?: () => void;
sendMessage?: (event: React.UIEvent, messageInput?: string) => void;
handleInputChange?: (event: React.ChangeEvent<HTMLTextAreaElement>) => void;
Expand All @@ -100,6 +100,10 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
input = '',
model,
setModel,
provider,
setProvider,
modelList,
providerList,
sendMessage,
handleInputChange,
enhancePrompt,
Expand All @@ -108,7 +112,6 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
ref,
) => {
const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200;

return (
<div
ref={ref}
Expand Down Expand Up @@ -156,7 +159,9 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
<ModelSelector
model={model}
setModel={setModel}
modelList={MODEL_LIST}
provider={provider}
setProvider={setProvider}
modelList={modelList}
providerList={providerList}
/>
<div
Expand All @@ -174,7 +179,6 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
}

event.preventDefault();

sendMessage?.(event);
}
}}
Expand Down
23 changes: 20 additions & 3 deletions app/components/chat/Chat.client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { useChatHistory } from '~/lib/persistence';
import { chatStore } from '~/lib/stores/chat';
import { workbenchStore } from '~/lib/stores/workbench';
import { fileModificationsToHTML } from '~/utils/diff';
import { DEFAULT_MODEL } from '~/utils/constants';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, initializeModelList, isInitialized, MODEL_LIST } from '~/utils/constants';
import { cubicEasingFn } from '~/utils/easings';
import { createScopedLogger, renderLogger } from '~/utils/logger';
import { BaseChat } from './BaseChat';
Expand Down Expand Up @@ -74,6 +74,19 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp

const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
const [model, setModel] = useState(DEFAULT_MODEL);
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
const [modelList, setModelList] = useState(MODEL_LIST);
const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]);
const initialize = async () => {
if (!isInitialized) {
const models= await initializeModelList();
const modelList = models;
const providerList = [...new Set([...models.map((m) => m.provider),"Ollama","OpenAILike"])];
setModelList(modelList);
setProviderList(providerList);
}
};
initialize();

const { showChat } = useStore(chatStore);

Expand Down Expand Up @@ -182,15 +195,15 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
* manually reset the input and we'd have to manually pass in file attachments. However, those
* aren't relevant here.
*/
append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${diff}\n\n${_input}` });

/**
* After sending a new message we reset all modifications since the model
* should now be aware of all the changes.
*/
workbenchStore.resetAllFileModifications();
} else {
append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${_input}` });
}

setInput('');
Expand All @@ -215,6 +228,10 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
sendMessage={sendMessage}
model={model}
setModel={setModel}
provider={provider}
setProvider={setProvider}
modelList={modelList}
providerList={providerList}
messageRef={messageRef}
scrollRef={scrollRef}
handleInputChange={handleInputChange}
Expand Down
3 changes: 1 addition & 2 deletions app/entry.server.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import { renderToReadableStream } from 'react-dom/server';
import { renderHeadToString } from 'remix-island';
import { Head } from './root';
import { themeStore } from '~/lib/stores/theme';
import { initializeModelList } from '~/utils/constants';

export default async function handleRequest(
request: Request,
Expand All @@ -14,7 +13,7 @@ export default async function handleRequest(
remixContext: EntryContext,
_loadContext: AppLoadContext,
) {
await initializeModelList();


const readable = await renderToReadableStream(<RemixServer context={remixContext} url={request.url} />, {
signal: request.signal,
Expand Down
1 change: 0 additions & 1 deletion app/lib/.server/llm/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import { createOpenAI } from '@ai-sdk/openai';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { ollama } from 'ollama-ai-provider';
import { createOpenRouter } from "@openrouter/ai-sdk-provider";
import { mistral } from '@ai-sdk/mistral';
import { createMistral } from '@ai-sdk/mistral';

export function getAnthropicModel(apiKey: string, model: string) {
Expand Down
37 changes: 21 additions & 16 deletions app/lib/.server/llm/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { streamText as _streamText, convertToCoreMessages } from 'ai';
import { getModel } from '~/lib/.server/llm/model';
import { MAX_TOKENS } from './constants';
import { getSystemPrompt } from './prompts';
import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, hasModel } from '~/utils/constants';

interface ToolResult<Name extends string, Args, Result> {
toolCallId: string;
Expand All @@ -25,42 +25,47 @@ export type Messages = Message[];
export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;

function extractModelFromMessage(message: Message): { model: string; content: string } {
const modelRegex = /^\[Model: (.*?)\]\n\n/;
const modelRegex = /^\[Model: (.*?)Provider: (.*?)\]\n\n/;
const match = message.content.match(modelRegex);

if (match) {
const model = match[1];
const content = message.content.replace(modelRegex, '');
return { model, content };
if (!match) {
return { model: DEFAULT_MODEL, content: message.content,provider: DEFAULT_PROVIDER };
}

const [_,model,provider] = match;
const content = message.content.replace(modelRegex, '');
return { model, content ,provider};
// Default model if not specified
return { model: DEFAULT_MODEL, content: message.content };

}

export function streamText(messages: Messages, env: Env, options?: StreamingOptions) {
let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER;
const lastMessage = messages.findLast((message) => message.role === 'user');
if (lastMessage) {
const { model, provider } = extractModelFromMessage(lastMessage);
if (hasModel(model, provider)) {
currentModel = model;
currentProvider = provider;
}
}
const processedMessages = messages.map((message) => {
if (message.role === 'user') {
const { model, content } = extractModelFromMessage(message);
if (model && MODEL_LIST.find((m) => m.name === model)) {
currentModel = model; // Update the current model
}
const { content } = extractModelFromMessage(message);
return { ...message, content };
}
return message;
});

const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER;

const coreMessages = convertToCoreMessages(processedMessages);
return _streamText({
model: getModel(provider, currentModel, env),
model: getModel(currentProvider, currentModel, env),
system: getSystemPrompt(),
maxTokens: MAX_TOKENS,
// headers: {
// 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
// },
messages: convertToCoreMessages(processedMessages),
messages: coreMessages,
...options,
});
}
8 changes: 7 additions & 1 deletion app/routes/_index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import { ClientOnly } from 'remix-utils/client-only';
import { BaseChat } from '~/components/chat/BaseChat';
import { Chat } from '~/components/chat/Chat.client';
import { Header } from '~/components/header/Header';
import { useState } from 'react';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST } from '~/utils/constants';

export const meta: MetaFunction = () => {
return [{ title: 'Bolt' }, { name: 'description', content: 'Talk with Bolt, an AI assistant from StackBlitz' }];
Expand All @@ -11,10 +13,14 @@ export const meta: MetaFunction = () => {
export const loader = () => json({});

export default function Index() {
const [model, setModel] = useState(DEFAULT_MODEL);
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
const [modelList, setModelList] = useState(MODEL_LIST);
const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]);
return (
<div className="flex flex-col h-full w-full">
<Header />
<ClientOnly fallback={<BaseChat />}>{() => <Chat />}</ClientOnly>
<ClientOnly fallback={<BaseChat model={model} modelList={modelList} provider={provider} providerList={providerList} setModel={setModel} setProvider={setProvider}/>}>{() => <Chat />}</ClientOnly>
</div>
);
}
10 changes: 6 additions & 4 deletions app/routes/api.models.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { json } from '@remix-run/cloudflare';
import { MODEL_LIST } from '~/utils/constants';
import { json, type LoaderFunctionArgs } from '@remix-run/cloudflare';
import { initializeModelList } from '~/utils/tools';

export async function loader() {
return json(MODEL_LIST);
export async function loader({context}: LoaderFunctionArgs) {
const { env } = context.cloudflare;
const modelList = await initializeModelList(env);
return json(modelList);
}
Loading
Loading