Skip to content

Commit

Permalink
Update constants.ts and fix model name bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ZerxZ committed Oct 26, 2024
1 parent 8e7220e commit 5700108
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 96 deletions.
52 changes: 28 additions & 24 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ import { Menu } from '~/components/sidebar/Menu.client';
import { IconButton } from '~/components/ui/IconButton';
import { Workbench } from '~/components/workbench/Workbench.client';
import { classNames } from '~/utils/classNames';
import { MODEL_LIST, DEFAULT_PROVIDER } from '~/utils/constants';
import { Messages } from './Messages.client';
import { SendButton } from './SendButton.client';
import { useState } from 'react';

import styles from './BaseChat.module.scss';
import type { ModelInfo } from '~/utils/types';

const EXAMPLE_PROMPTS = [
{ text: 'Build a todo app in React using Tailwind' },
Expand All @@ -22,36 +21,33 @@ const EXAMPLE_PROMPTS = [
{ text: 'How do I center a div?' },
];

const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))]

const ModelSelector = ({ model, setModel, modelList, providerList }) => {
const [provider, setProvider] = useState(DEFAULT_PROVIDER);

function ModelSelector({ model, setModel ,provider,setProvider,modelList,providerList}) {




return (
<div className="mb-2">
<select
value={provider}
onChange={(e) => {
setProvider(e.target.value);
const firstModel = [...modelList].find(m => m.provider == e.target.value);
const firstModel = modelList.find((m) => m.provider === e.target.value);
setModel(firstModel ? firstModel.name : '');
}}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{providerList.map((provider) => (
<option key={provider} value={provider}>
{provider}
</option>
))}
<option key="Ollama" value="Ollama">
Ollama
</option>
<option key="OpenAILike" value="OpenAILike">
OpenAILike
</option>
{providerList.map(providerName=>( <option key={providerName} value={providerName}>
{providerName}
</option>))}
</select>
<select
value={model}
onChange={(e) => setModel(e.target.value)}
onChange={(e) => {
setModel(e.target.value)
}}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{[...modelList].filter(e => e.provider == provider && e.name).map((modelOption) => (
Expand All @@ -62,7 +58,7 @@ const ModelSelector = ({ model, setModel, modelList, providerList }) => {
</select>
</div>
);
};
}

const TEXTAREA_MIN_HEIGHT = 76;

Expand All @@ -77,8 +73,12 @@ interface BaseChatProps {
enhancingPrompt?: boolean;
promptEnhanced?: boolean;
input?: string;
model: string;
setModel: (model: string) => void;
model?: string;
setModel?: (model: string) => void;
provider?: string;
setProvider?: (provider: string) => void;
modelList?: ModelInfo[];
providerList?: string[];
handleStop?: () => void;
sendMessage?: (event: React.UIEvent, messageInput?: string) => void;
handleInputChange?: (event: React.ChangeEvent<HTMLTextAreaElement>) => void;
Expand All @@ -100,6 +100,10 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
input = '',
model,
setModel,
provider,
setProvider,
modelList,
providerList,
sendMessage,
handleInputChange,
enhancePrompt,
Expand All @@ -108,7 +112,6 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
ref,
) => {
const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200;

return (
<div
ref={ref}
Expand Down Expand Up @@ -156,7 +159,9 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
<ModelSelector
model={model}
setModel={setModel}
modelList={MODEL_LIST}
provider={provider}
setProvider={setProvider}
modelList={modelList}
providerList={providerList}
/>
<div
Expand All @@ -174,7 +179,6 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
}

event.preventDefault();

sendMessage?.(event);
}
}}
Expand Down
23 changes: 20 additions & 3 deletions app/components/chat/Chat.client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { useChatHistory } from '~/lib/persistence';
import { chatStore } from '~/lib/stores/chat';
import { workbenchStore } from '~/lib/stores/workbench';
import { fileModificationsToHTML } from '~/utils/diff';
import { DEFAULT_MODEL } from '~/utils/constants';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, initializeModelList, isInitialized, MODEL_LIST } from '~/utils/constants';
import { cubicEasingFn } from '~/utils/easings';
import { createScopedLogger, renderLogger } from '~/utils/logger';
import { BaseChat } from './BaseChat';
Expand Down Expand Up @@ -74,6 +74,19 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp

const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
const [model, setModel] = useState(DEFAULT_MODEL);
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
const [modelList, setModelList] = useState(MODEL_LIST);
const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]);
const initialize = async () => {
if (!isInitialized) {
const models= await initializeModelList();
const modelList = models;
const providerList = [...new Set([...models.map((m) => m.provider),"Ollama","OpenAILike"])];
setModelList(modelList);
setProviderList(providerList);
}
};
initialize();

const { showChat } = useStore(chatStore);

Expand Down Expand Up @@ -182,15 +195,15 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
* manually reset the input and we'd have to manually pass in file attachments. However, those
* aren't relevant here.
*/
append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${diff}\n\n${_input}` });

/**
* After sending a new message we reset all modifications since the model
* should now be aware of all the changes.
*/
workbenchStore.resetAllFileModifications();
} else {
append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${_input}` });
}

setInput('');
Expand All @@ -215,6 +228,10 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
sendMessage={sendMessage}
model={model}
setModel={setModel}
provider={provider}
setProvider={setProvider}
modelList={modelList}
providerList={providerList}
messageRef={messageRef}
scrollRef={scrollRef}
handleInputChange={handleInputChange}
Expand Down
3 changes: 1 addition & 2 deletions app/entry.server.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import { renderToReadableStream } from 'react-dom/server';
import { renderHeadToString } from 'remix-island';
import { Head } from './root';
import { themeStore } from '~/lib/stores/theme';
import { initializeModelList } from '~/utils/constants';

export default async function handleRequest(
request: Request,
Expand All @@ -14,7 +13,7 @@ export default async function handleRequest(
remixContext: EntryContext,
_loadContext: AppLoadContext,
) {
await initializeModelList();


const readable = await renderToReadableStream(<RemixServer context={remixContext} url={request.url} />, {
signal: request.signal,
Expand Down
1 change: 0 additions & 1 deletion app/lib/.server/llm/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import { createOpenAI } from '@ai-sdk/openai';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { ollama } from 'ollama-ai-provider';
import { createOpenRouter } from "@openrouter/ai-sdk-provider";
import { mistral } from '@ai-sdk/mistral';
import { createMistral } from '@ai-sdk/mistral';

export function getAnthropicModel(apiKey: string, model: string) {
Expand Down
37 changes: 21 additions & 16 deletions app/lib/.server/llm/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { streamText as _streamText, convertToCoreMessages } from 'ai';
import { getModel } from '~/lib/.server/llm/model';
import { MAX_TOKENS } from './constants';
import { getSystemPrompt } from './prompts';
import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, hasModel } from '~/utils/constants';

interface ToolResult<Name extends string, Args, Result> {
toolCallId: string;
Expand All @@ -25,42 +25,47 @@ export type Messages = Message[];
export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;

function extractModelFromMessage(message: Message): { model: string; content: string } {
const modelRegex = /^\[Model: (.*?)\]\n\n/;
const modelRegex = /^\[Model: (.*?)Provider: (.*?)\]\n\n/;
const match = message.content.match(modelRegex);

if (match) {
const model = match[1];
const content = message.content.replace(modelRegex, '');
return { model, content };
if (!match) {
return { model: DEFAULT_MODEL, content: message.content,provider: DEFAULT_PROVIDER };
}

const [_,model,provider] = match;
const content = message.content.replace(modelRegex, '');
return { model, content ,provider};
// Default model if not specified
return { model: DEFAULT_MODEL, content: message.content };

}

export function streamText(messages: Messages, env: Env, options?: StreamingOptions) {
let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER;
const lastMessage = messages.findLast((message) => message.role === 'user');
if (lastMessage) {
const { model, provider } = extractModelFromMessage(lastMessage);
if (hasModel(model, provider)) {
currentModel = model;
currentProvider = provider;
}
}
const processedMessages = messages.map((message) => {
if (message.role === 'user') {
const { model, content } = extractModelFromMessage(message);
if (model && MODEL_LIST.find((m) => m.name === model)) {
currentModel = model; // Update the current model
}
const { content } = extractModelFromMessage(message);
return { ...message, content };
}
return message;
});

const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER;

const coreMessages = convertToCoreMessages(processedMessages);
return _streamText({
model: getModel(provider, currentModel, env),
model: getModel(currentProvider, currentModel, env),
system: getSystemPrompt(),
maxTokens: MAX_TOKENS,
// headers: {
// 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
// },
messages: convertToCoreMessages(processedMessages),
messages: coreMessages,
...options,
});
}
8 changes: 7 additions & 1 deletion app/routes/_index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import { ClientOnly } from 'remix-utils/client-only';
import { BaseChat } from '~/components/chat/BaseChat';
import { Chat } from '~/components/chat/Chat.client';
import { Header } from '~/components/header/Header';
import { useState } from 'react';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST } from '~/utils/constants';

export const meta: MetaFunction = () => {
return [{ title: 'Bolt' }, { name: 'description', content: 'Talk with Bolt, an AI assistant from StackBlitz' }];
Expand All @@ -11,10 +13,14 @@ export const meta: MetaFunction = () => {
export const loader = () => json({});

export default function Index() {
const [model, setModel] = useState(DEFAULT_MODEL);
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
const [modelList, setModelList] = useState(MODEL_LIST);
const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]);
return (
<div className="flex flex-col h-full w-full">
<Header />
<ClientOnly fallback={<BaseChat />}>{() => <Chat />}</ClientOnly>
<ClientOnly fallback={<BaseChat model={model} modelList={modelList} provider={provider} providerList={providerList} setModel={setModel} setProvider={setProvider}/>}>{() => <Chat />}</ClientOnly>
</div>
);
}
10 changes: 6 additions & 4 deletions app/routes/api.models.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { json } from '@remix-run/cloudflare';
import { MODEL_LIST } from '~/utils/constants';
import { json, type LoaderFunctionArgs } from '@remix-run/cloudflare';
import { initializeModelList } from '~/utils/tools';

export async function loader() {
return json(MODEL_LIST);
export async function loader({context}: LoaderFunctionArgs) {
const { env } = context.cloudflare;
const modelList = await initializeModelList(env);
return json(modelList);
}
Loading

0 comments on commit 5700108

Please sign in to comment.