Skip to content

Commit

Permalink
chore: Update constants.ts and fix model name bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ZerxZ committed Oct 26, 2024
1 parent 8e7220e commit 1d8d6f1
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 63 deletions.
29 changes: 13 additions & 16 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { Menu } from '~/components/sidebar/Menu.client';
import { IconButton } from '~/components/ui/IconButton';
import { Workbench } from '~/components/workbench/Workbench.client';
import { classNames } from '~/utils/classNames';
import { MODEL_LIST, DEFAULT_PROVIDER } from '~/utils/constants';
import { DEFAULT_PROVIDER, initializeModelList } from '~/utils/constants';
import { Messages } from './Messages.client';
import { SendButton } from './SendButton.client';
import { useState } from 'react';
Expand All @@ -21,44 +21,41 @@ const EXAMPLE_PROMPTS = [
{ text: 'Make a space invaders game' },
{ text: 'How do I center a div?' },
];

const MODEL_LIST= await initializeModelList();
const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))]

const ModelSelector = ({ model, setModel, modelList, providerList }) => {
function ModelSelector (args) {
const {model, setModel, modelList,providerList} = args;
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
return (
<div className="mb-2">
<select
value={provider}
onChange={(e) => {
setProvider(e.target.value);
const firstModel = [...modelList].find(m => m.provider == e.target.value);
const firstModel = [...modelList].find((m) => m.provider == e.target.value);
setModel(firstModel ? firstModel.name : '');
}}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{providerList.map((provider) => (
{providerList.map(provider => (
<option key={provider} value={provider}>
{provider}
</option>
))}
<option key="Ollama" value="Ollama">
Ollama
</option>
<option key="OpenAILike" value="OpenAILike">
OpenAILike
</option>
</select>
<select
value={model}
onChange={(e) => setModel(e.target.value)}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{[...modelList].filter(e => e.provider == provider && e.name).map((modelOption) => (
<option key={modelOption.name} value={modelOption.name}>
{modelOption.label}
</option>
))}
{[...modelList]
.filter((e) => e.provider == provider && e.name)
.map((modelOption) => (
<option key={modelOption.name} value={modelOption.name}>
{modelOption.label}
</option>
))}
</select>
</div>
);
Expand Down
3 changes: 1 addition & 2 deletions app/entry.server.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import { renderToReadableStream } from 'react-dom/server';
import { renderHeadToString } from 'remix-island';
import { Head } from './root';
import { themeStore } from '~/lib/stores/theme';
import { initializeModelList } from '~/utils/constants';

export default async function handleRequest(
request: Request,
Expand All @@ -14,7 +13,7 @@ export default async function handleRequest(
remixContext: EntryContext,
_loadContext: AppLoadContext,
) {
await initializeModelList();


const readable = await renderToReadableStream(<RemixServer context={remixContext} url={request.url} />, {
signal: request.signal,
Expand Down
5 changes: 3 additions & 2 deletions app/routes/api.models.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { json } from '@remix-run/cloudflare';
import { MODEL_LIST } from '~/utils/constants';
import { initializeModelList } from '~/utils/tools';

export async function loader() {
return json(MODEL_LIST);
const modelList = await initializeModelList();
return json(modelList);
}
63 changes: 20 additions & 43 deletions app/utils/constants.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import type { ModelInfo, OllamaApiResponse, OllamaModel } from './types';

import type { ModelInfo } from './types';
export const WORK_DIR_NAME = 'project';
export const WORK_DIR = `/home/${WORK_DIR_NAME}`;
export const MODIFICATIONS_TAG_NAME = 'bolt_file_modifications';
Expand Down Expand Up @@ -47,49 +46,27 @@ const staticModels: ModelInfo[] = [

export let MODEL_LIST: ModelInfo[] = [...staticModels];

async function getOllamaModels(): Promise<ModelInfo[]> {
try {
const base_url = import.meta.env.OLLAMA_API_BASE_URL || "http://localhost:11434";
const response = await fetch(`${base_url}/api/tags`);
const data = await response.json() as OllamaApiResponse;
export const IS_SERVER = typeof window === 'undefined';

return data.models.map((model: OllamaModel) => ({
name: model.name,
label: `${model.name} (${model.details.parameter_size})`,
provider: 'Ollama',
}));
} catch (e) {
return [];
}
export function setModelList(models: ModelInfo[]): void {
MODEL_LIST = models;
}

async function getOpenAILikeModels(): Promise<ModelInfo[]> {
try {
const base_url =import.meta.env.OPENAI_LIKE_API_BASE_URL || "";
if (!base_url) {
return [];
}
const api_key = import.meta.env.OPENAI_LIKE_API_KEY ?? "";
const response = await fetch(`${base_url}/models`, {
headers: {
Authorization: `Bearer ${api_key}`,
}
});
const res = await response.json() as any;
return res.data.map((model: any) => ({
name: model.id,
label: model.id,
provider: 'OpenAILike',
}));
}catch (e) {
return []
}

export function getStaticModels(): ModelInfo[] {
return [...staticModels];
}
async function initializeModelList(): Promise<void> {
const ollamaModels = await getOllamaModels();
const openAiLikeModels = await getOpenAILikeModels();
MODEL_LIST = [...ollamaModels,...openAiLikeModels, ...staticModels];
let isInitialized = false;

export async function initializeModelList(): Promise<ModelInfo[]> {
if (isInitialized ) {
return MODEL_LIST;
}
if (IS_SERVER){
isInitialized = true;
return MODEL_LIST;
}
isInitialized = true;
const response = await fetch('/api/models');
MODEL_LIST = (await response.json()) as ModelInfo[];
return MODEL_LIST;
}
initializeModelList().then();
export { getOllamaModels, getOpenAILikeModels, initializeModelList };
106 changes: 106 additions & 0 deletions app/utils/tools.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import type { ModelInfo, OllamaApiResponse, OllamaModel } from '~/utils/types';
import { getStaticModels,setModelList } from '~/utils/constants';
import { env } from 'node:process';


export let MODEL_LIST: ModelInfo[] = [...getStaticModels()];

export function getAPIKey(provider: string) {
switch (provider) {
case 'Anthropic':
return env.ANTHROPIC_API_KEY;
case 'OpenAI':
return env.OPENAI_API_KEY;
case 'Google':
return env.GOOGLE_GENERATIVE_AI_API_KEY;
case 'Groq':
return env.GROQ_API_KEY;
case 'OpenRouter':
return env.OPEN_ROUTER_API_KEY;
case 'Deepseek':
return env.DEEPSEEK_API_KEY;
case 'Mistral':
return env.MISTRAL_API_KEY;
case "OpenAILike":
return import.meta.env.OPENAI_LIKE_API_KEY || env.OPENAI_LIKE_API_KEY;
default:
return "";
}
}
export function getBaseURL( provider: string){
switch (provider) {
case 'OpenAILike':
return import.meta.env.OPENAI_LIKE_API_BASE_URL || env.OPENAI_LIKE_API_BASE_URL || "";
case 'Ollama':
return import.meta.env.OLLAMA_API_BASE_URL || env.OLLAMA_API_BASE_URL || "http://localhost:11434";
default:
return "";
}
}

let isInitialized = false;
async function getOllamaModels(): Promise<ModelInfo[]> {
try {
const base_url = getBaseURL("Ollama") ;
const response = await fetch(`${base_url}/api/tags`);
const data = await response.json() as OllamaApiResponse;
return data.models.map((model: OllamaModel) => ({
name: model.name,
label: `${model.name} (${model.details.parameter_size})`,
provider: 'Ollama',
}));
} catch (e) {
return [{
name: "Empty",
label: "Empty",
provider: "Ollama"
}];
}
}

async function getOpenAILikeModels(): Promise<ModelInfo[]> {
try {
const base_url = getBaseURL("OpenAILike") ;
if (!base_url) {
return [{
name: "Empty",
label: "Empty",
provider: "OpenAILike"
}];
}
const api_key = getAPIKey("OpenAILike") ?? "";
const response = await fetch(`${base_url}/models`, {
headers: {
Authorization: `Bearer ${api_key}`,
}
});
const res = await response.json() as any;
return res.data.map((model: any) => ({
name: model.id,
label: model.id,
provider: 'OpenAILike',
}));
}catch (e) {
console.error(e);
return [{
name: "Empty",
label: "Empty",
provider: "OpenAILike"
}];
}
}


async function initializeModelList(): Promise<ModelInfo[]> {
if (isInitialized) {
return MODEL_LIST;
}
isInitialized = true;
const ollamaModels = await getOllamaModels();
const openAiLikeModels = await getOpenAILikeModels();
MODEL_LIST = [...getStaticModels(), ...ollamaModels, ...openAiLikeModels];
setModelList(MODEL_LIST);
return MODEL_LIST;
}

export { getOllamaModels, getOpenAILikeModels, initializeModelList };

0 comments on commit 1d8d6f1

Please sign in to comment.