Skip to content

Commit

Permalink
Merge pull request ChatGPTNextWeb#4930 from ConnectAI-E/feature-azure
Browse files Browse the repository at this point in the history
support azure deployment name
  • Loading branch information
Dogtiti authored Jul 6, 2024
2 parents 70907ea + 6dc4844 commit 2d1f522
Show file tree
Hide file tree
Showing 17 changed files with 204 additions and 95 deletions.
2 changes: 1 addition & 1 deletion app/api/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
break;
case ModelProvider.GPT:
default:
if (serverConfig.isAzure) {
if (req.nextUrl.pathname.includes("azure/deployments")) {
systemApiKey = serverConfig.azureApiKey;
} else {
systemApiKey = serverConfig.apiKey;
Expand Down
57 changes: 57 additions & 0 deletions app/api/azure/[...path]/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { getServerSideConfig } from "@/app/config/server";
import { ModelProvider } from "@/app/constant";
import { prettyObject } from "@/app/utils/format";
import { NextRequest, NextResponse } from "next/server";
import { auth } from "../../auth";
import { requestOpenai } from "../../common";

async function handle(
req: NextRequest,
{ params }: { params: { path: string[] } },
) {
console.log("[Azure Route] params ", params);

if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}

const subpath = params.path.join("/");

const authResult = auth(req, ModelProvider.GPT);
if (authResult.error) {
return NextResponse.json(authResult, {
status: 401,
});
}

try {
return await requestOpenai(req);
} catch (e) {
console.error("[Azure] ", e);
return NextResponse.json(prettyObject(e));
}
}

export const GET = handle;
export const POST = handle;

export const runtime = "edge";
export const preferredRegion = [
"arn1",
"bom1",
"cdg1",
"cle1",
"cpt1",
"dub1",
"fra1",
"gru1",
"hnd1",
"iad1",
"icn1",
"kix1",
"lhr1",
"pdx1",
"sfo1",
"sin1",
"syd1",
];
22 changes: 12 additions & 10 deletions app/api/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@ import {
ServiceProvider,
} from "../constant";
import { isModelAvailableInServer } from "../utils/model";
import { makeAzurePath } from "../azure";

const serverConfig = getServerSideConfig();

export async function requestOpenai(req: NextRequest) {
const controller = new AbortController();

const isAzure = req.nextUrl.pathname.includes("azure/deployments");

var authValue,
authHeaderName = "";
if (serverConfig.isAzure) {
if (isAzure) {
authValue =
req.headers
.get("Authorization")
Expand Down Expand Up @@ -56,14 +57,15 @@ export async function requestOpenai(req: NextRequest) {
10 * 60 * 1000,
);

if (serverConfig.isAzure) {
if (!serverConfig.azureApiVersion) {
return NextResponse.json({
error: true,
message: `missing AZURE_API_VERSION in server env vars`,
});
}
path = makeAzurePath(path, serverConfig.azureApiVersion);
if (isAzure) {
const azureApiVersion =
req?.nextUrl?.searchParams?.get("api-version") ||
serverConfig.azureApiVersion;
baseUrl = baseUrl.split("/deployments").shift() as string;
path = `${req.nextUrl.pathname.replaceAll(
"/api/azure/",
"",
)}?api-version=${azureApiVersion}`;
}

const fetchUrl = `${baseUrl}/${path}`;
Expand Down
9 changes: 0 additions & 9 deletions app/azure.ts

This file was deleted.

19 changes: 13 additions & 6 deletions app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export interface RequestMessage {

export interface LLMConfig {
model: string;
providerName?: string;
temperature?: number;
top_p?: number;
stream?: boolean;
Expand All @@ -54,6 +55,7 @@ export interface LLMUsage {

export interface LLMModel {
name: string;
displayName?: string;
available: boolean;
provider: LLMModelProvider;
}
Expand Down Expand Up @@ -160,10 +162,14 @@ export function getHeaders() {
Accept: "application/json",
};
const modelConfig = useChatStore.getState().currentSession().mask.modelConfig;
const isGoogle = modelConfig.model.startsWith("gemini");
const isAzure = accessStore.provider === ServiceProvider.Azure;
const isAnthropic = accessStore.provider === ServiceProvider.Anthropic;
const authHeader = isAzure ? "api-key" : isAnthropic ? 'x-api-key' : "Authorization";
const isGoogle = modelConfig.providerName == ServiceProvider.Google;
const isAzure = modelConfig.providerName === ServiceProvider.Azure;
const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic;
const authHeader = isAzure
? "api-key"
: isAnthropic
? "x-api-key"
: "Authorization";
const apiKey = isGoogle
? accessStore.googleApiKey
: isAzure
Expand All @@ -172,7 +178,8 @@ export function getHeaders() {
? accessStore.anthropicApiKey
: accessStore.openaiApiKey;
const clientConfig = getClientConfig();
const makeBearer = (s: string) => `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
const makeBearer = (s: string) =>
`${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
const validString = (x: string) => x && x.length > 0;

// when using google api in app, not set auth header
Expand All @@ -185,7 +192,7 @@ export function getHeaders() {
validString(accessStore.accessCode)
) {
// access_code must send with header named `Authorization`, will using in auth middleware.
headers['Authorization'] = makeBearer(
headers["Authorization"] = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
Expand Down
53 changes: 41 additions & 12 deletions app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"use client";
// azure and openai, using same models. so using same LLMApi.
import {
ApiPath,
DEFAULT_API_HOST,
DEFAULT_MODELS,
OpenaiPath,
Azure,
REQUEST_TIMEOUT_MS,
ServiceProvider,
} from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { collectModelsWithDefaultModel } from "@/app/utils/model";

import {
ChatOptions,
Expand All @@ -24,7 +27,6 @@ import {
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";
import { getClientConfig } from "@/app/config/client";
import { makeAzurePath } from "@/app/azure";
import {
getMessageTextContent,
getMessageImages,
Expand Down Expand Up @@ -62,33 +64,31 @@ export class ChatGPTApi implements LLMApi {

let baseUrl = "";

const isAzure = path.includes("deployments");
if (accessStore.useCustomConfig) {
const isAzure = accessStore.provider === ServiceProvider.Azure;

if (isAzure && !accessStore.isValidAzure()) {
throw Error(
"incomplete azure config, please check it in your settings page",
);
}

if (isAzure) {
path = makeAzurePath(path, accessStore.azureApiVersion);
}

baseUrl = isAzure ? accessStore.azureUrl : accessStore.openaiUrl;
}

if (baseUrl.length === 0) {
const isApp = !!getClientConfig()?.isApp;
baseUrl = isApp
? DEFAULT_API_HOST + "/proxy" + ApiPath.OpenAI
: ApiPath.OpenAI;
const apiPath = isAzure ? ApiPath.Azure : ApiPath.OpenAI;
baseUrl = isApp ? DEFAULT_API_HOST + "/proxy" + apiPath : apiPath;
}

if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, baseUrl.length - 1);
}
if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.OpenAI)) {
if (
!baseUrl.startsWith("http") &&
!isAzure &&
!baseUrl.startsWith(ApiPath.OpenAI)
) {
baseUrl = "https://" + baseUrl;
}

Expand All @@ -113,6 +113,7 @@ export class ChatGPTApi implements LLMApi {
...useChatStore.getState().currentSession().mask.modelConfig,
...{
model: options.config.model,
providerName: options.config.providerName,
},
};

Expand Down Expand Up @@ -140,7 +141,35 @@ export class ChatGPTApi implements LLMApi {
options.onController?.(controller);

try {
const chatPath = this.path(OpenaiPath.ChatPath);
let chatPath = "";
if (modelConfig.providerName === ServiceProvider.Azure) {
// find model, and get displayName as deployName
const { models: configModels, customModels: configCustomModels } =
useAppConfig.getState();
const {
defaultModel,
customModels: accessCustomModels,
useCustomConfig,
} = useAccessStore.getState();
const models = collectModelsWithDefaultModel(
configModels,
[configCustomModels, accessCustomModels].join(","),
defaultModel,
);
const model = models.find(
(model) =>
model.name === modelConfig.model &&
model?.provider?.providerName === ServiceProvider.Azure,
);
chatPath = this.path(
Azure.ChatPath(
(model?.displayName ?? model?.name) as string,
useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
),
);
} else {
chatPath = this.path(OpenaiPath.ChatPath);
}
const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
Expand Down
35 changes: 23 additions & 12 deletions app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ import {
Path,
REQUEST_TIMEOUT_MS,
UNFINISHED_INPUT,
ServiceProvider,
} from "../constant";
import { Avatar } from "./emoji";
import { ContextPrompts, MaskAvatar, MaskConfig } from "./mask";
Expand Down Expand Up @@ -448,6 +449,9 @@ export function ChatActions(props: {

// switch model
const currentModel = chatStore.currentSession().mask.modelConfig.model;
const currentProviderName =
chatStore.currentSession().mask.modelConfig?.providerName ||
ServiceProvider.OpenAI;
const allModels = useAllModels();
const models = useMemo(() => {
const filteredModels = allModels.filter((m) => m.available);
Expand Down Expand Up @@ -479,13 +483,13 @@ export function ChatActions(props: {
const isUnavaliableModel = !models.some((m) => m.name === currentModel);
if (isUnavaliableModel && models.length > 0) {
// show next model to default model if exist
let nextModel: ModelType = (
models.find((model) => model.isDefault) || models[0]
).name;
chatStore.updateCurrentSession(
(session) => (session.mask.modelConfig.model = nextModel),
);
showToast(nextModel);
let nextModel = models.find((model) => model.isDefault) || models[0];
chatStore.updateCurrentSession((session) => {
session.mask.modelConfig.model = nextModel.name;
session.mask.modelConfig.providerName = nextModel?.provider
?.providerName as ServiceProvider;
});
showToast(nextModel.name);
}
}, [chatStore, currentModel, models]);

Expand Down Expand Up @@ -573,19 +577,26 @@ export function ChatActions(props: {

{showModelSelector && (
<Selector
defaultSelectedValue={currentModel}
defaultSelectedValue={`${currentModel}@${currentProviderName}`}
items={models.map((m) => ({
title: m.displayName,
value: m.name,
title: `${m.displayName}${
m?.provider?.providerName
? "(" + m?.provider?.providerName + ")"
: ""
}`,
value: `${m.name}@${m?.provider?.providerName}`,
}))}
onClose={() => setShowModelSelector(false)}
onSelection={(s) => {
if (s.length === 0) return;
const [model, providerName] = s[0].split("@");
chatStore.updateCurrentSession((session) => {
session.mask.modelConfig.model = s[0] as ModelType;
session.mask.modelConfig.model = model as ModelType;
session.mask.modelConfig.providerName =
providerName as ServiceProvider;
session.mask.syncGlobalConfig = false;
});
showToast(s[0]);
showToast(model);
}}
/>
)}
Expand Down
11 changes: 7 additions & 4 deletions app/components/exporter.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ import { toBlob, toPng } from "html-to-image";
import { DEFAULT_MASK_AVATAR } from "../store/mask";

import { prettyObject } from "../utils/format";
import { EXPORT_MESSAGE_CLASS_NAME, ModelProvider } from "../constant";
import {
EXPORT_MESSAGE_CLASS_NAME,
ModelProvider,
ServiceProvider,
} from "../constant";
import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api";
import { getMessageTextContent } from "../utils";
import { identifyDefaultClaudeModel } from "../utils/checkers";

const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
loading: () => <LoadingIcon />,
Expand Down Expand Up @@ -314,9 +317,9 @@ export function PreviewActions(props: {
setShouldExport(false);

var api: ClientApi;
if (config.modelConfig.model.startsWith("gemini")) {
if (config.modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(config.modelConfig.model)) {
} else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);
Expand Down
Loading

0 comments on commit 2d1f522

Please sign in to comment.