diff --git a/app/api/common.ts b/app/api/common.ts index b4c792d6ff0..495a12ccdbb 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -1,8 +1,8 @@ import { NextRequest, NextResponse } from "next/server"; import { getServerSideConfig } from "../config/server"; import { OPENAI_BASE_URL, ServiceProvider } from "../constant"; -import { isModelAvailableInServer } from "../utils/model"; import { cloudflareAIGatewayUrl } from "../utils/cloudflare"; +import { getModelProvider, isModelAvailableInServer } from "../utils/model"; const serverConfig = getServerSideConfig(); @@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) { .filter((v) => !!v && !v.startsWith("-") && v.includes(modelName)) .forEach((m) => { const [fullName, displayName] = m.split("="); - const [_, providerName] = fullName.split("@"); + const [_, providerName] = getModelProvider(fullName); if (providerName === "azure" && !displayName) { const [_, deployId] = (serverConfig?.azureUrl ?? "").split( "deployments/", diff --git a/app/components/chat.tsx b/app/components/chat.tsx index c5deeefa5c4..82d6c6e398a 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -120,6 +120,7 @@ import { createTTSPlayer } from "../utils/audio"; import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts"; import { isEmpty } from "lodash-es"; +import { getModelProvider } from "../utils/model"; const localStorage = safeLocalStorage(); @@ -648,7 +649,7 @@ export function ChatActions(props: { onClose={() => setShowModelSelector(false)} onSelection={(s) => { if (s.length === 0) return; - const [model, providerName] = s[0].split("@"); + const [model, providerName] = getModelProvider(s[0]); chatStore.updateTargetSession(session, (session) => { session.mask.modelConfig.model = model as ModelType; session.mask.modelConfig.providerName = diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index f2297e10b49..e845bfeac7a 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -7,6 +7,7 @@ import { ListItem, Select } from "./ui-lib"; import { useAllModels } from "../utils/hooks"; import { groupBy } from "lodash-es"; import styles from "./model-config.module.scss"; +import { getModelProvider } from "../utils/model"; export function ModelConfigList(props: { modelConfig: ModelConfig; @@ -28,7 +29,9 @@ export function ModelConfigList(props: { value={value} align="left" onChange={(e) => { - const [model, providerName] = e.currentTarget.value.split("@"); + const [model, providerName] = getModelProvider( + e.currentTarget.value, + ); props.updateConfig((config) => { config.model = ModalConfigValidator.model(model); config.providerName = providerName as ServiceProvider; @@ -247,7 +250,9 @@ export function ModelConfigList(props: { aria-label={Locale.Settings.CompressModel.Title} value={compressModelValue} onChange={(e) => { - const [model, providerName] = e.currentTarget.value.split("@"); + const [model, providerName] = getModelProvider( + e.currentTarget.value, + ); props.updateConfig((config) => { config.compressModel = ModalConfigValidator.model(model); config.compressProviderName = providerName as ServiceProvider; diff --git a/app/store/access.ts b/app/store/access.ts index 3b0e6357bc1..4796b2fe84e 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -21,6 +21,7 @@ import { getClientConfig } from "../config/client"; import { createPersistStore } from "../utils/store"; import { ensure } from "../utils/clone"; import { DEFAULT_CONFIG } from "./config"; +import { getModelProvider } from "../utils/model"; let fetchState = 0; // 0 not fetch, 1 fetching, 2 done @@ -226,9 +227,9 @@ export const useAccessStore = createPersistStore( .then((res) => { const defaultModel = res.defaultModel ?? ""; if (defaultModel !== "") { - const [model, providerName] = defaultModel.split("@"); + const [model, providerName] = getModelProvider(defaultModel); DEFAULT_CONFIG.modelConfig.model = model; - DEFAULT_CONFIG.modelConfig.providerName = providerName; + DEFAULT_CONFIG.modelConfig.providerName = providerName as any; } return res; diff --git a/app/utils/model.ts b/app/utils/model.ts index 0b62b53be09..a1b7df1b61e 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -37,6 +37,17 @@ const sortModelTable = (models: ReturnType) => } }); +/** + * get model name and provider from a formatted string, + * e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google` + * @param modelWithProvider model name with provider separated by last `@` char, + * @returns [model, provider] tuple, if no `@` char found, provider is undefined + */ +export function getModelProvider(modelWithProvider: string): [string, string?] { + const [model, provider] = modelWithProvider.split(/@(?!.*@)/); + return [model, provider]; +} + export function collectModelTable( models: readonly LLMModel[], customModels: string, @@ -79,10 +90,10 @@ export function collectModelTable( ); } else { // 1. find model by name, and set available value - const [customModelName, customProviderName] = name.split("@"); + const [customModelName, customProviderName] = getModelProvider(name); let count = 0; for (const fullName in modelTable) { - const [modelName, providerName] = fullName.split("@"); + const [modelName, providerName] = getModelProvider(fullName); if ( customModelName == modelName && (customProviderName === undefined || @@ -102,7 +113,7 @@ export function collectModelTable( } // 2. if model not exists, create new model with available value if (count === 0) { - let [customModelName, customProviderName] = name.split("@"); + let [customModelName, customProviderName] = getModelProvider(name); const provider = customProvider( customProviderName || customModelName, ); @@ -139,7 +150,7 @@ export function collectModelTableWithDefaultModel( for (const key of Object.keys(modelTable)) { if ( modelTable[key].available && - key.split("@").shift() == defaultModel + getModelProvider(key)[0] == defaultModel ) { modelTable[key].isDefault = true; break; diff --git a/test/model-provider.test.ts b/test/model-provider.test.ts new file mode 100644 index 00000000000..41f14be026c --- /dev/null +++ b/test/model-provider.test.ts @@ -0,0 +1,31 @@ +import { getModelProvider } from "../app/utils/model"; + +describe("getModelProvider", () => { + test("should return model and provider when input contains '@'", () => { + const input = "model@provider"; + const [model, provider] = getModelProvider(input); + expect(model).toBe("model"); + expect(provider).toBe("provider"); + }); + + test("should return model and undefined provider when input does not contain '@'", () => { + const input = "model"; + const [model, provider] = getModelProvider(input); + expect(model).toBe("model"); + expect(provider).toBeUndefined(); + }); + + test("should handle multiple '@' characters correctly", () => { + const input = "model@provider@extra"; + const [model, provider] = getModelProvider(input); + expect(model).toBe("model@provider"); + expect(provider).toBe("extra"); + }); + + test("should return empty strings when input is empty", () => { + const input = ""; + const [model, provider] = getModelProvider(input); + expect(model).toBe(""); + expect(provider).toBeUndefined(); + }); +});