From 11bee95eb8bb1ad22b737c706b8aaf0f573925c7 Mon Sep 17 00:00:00 2001 From: e-roy Date: Tue, 26 Dec 2023 06:27:51 -0500 Subject: [PATCH] update streaming --- package.json | 2 +- src/app/api/gemini-pro/route.ts | 69 ++++++++++-------------------- src/app/api/gemini-vision/route.ts | 61 ++++++++------------------ yarn.lock | 8 ++-- 4 files changed, 46 insertions(+), 94 deletions(-) diff --git a/package.json b/package.json index ade9e3d..d0e50b4 100644 --- a/package.json +++ b/package.json @@ -18,7 +18,7 @@ "@radix-ui/react-label": "^2.0.2", "@radix-ui/react-slider": "^1.1.2", "@radix-ui/react-slot": "^1.0.2", - "ai": "^2.2.29", + "ai": "^2.2.30", "class-variance-authority": "^0.7.0", "clsx": "^2.0.0", "lucide-react": "^0.297.0", diff --git a/src/app/api/gemini-pro/route.ts b/src/app/api/gemini-pro/route.ts index c5e90b3..43792fe 100644 --- a/src/app/api/gemini-pro/route.ts +++ b/src/app/api/gemini-pro/route.ts @@ -1,5 +1,5 @@ // api/gemini/route.ts -import { Message } from "ai"; +import { GoogleGenerativeAIStream, Message, StreamingTextResponse } from "ai"; import { GoogleGenerativeAI, @@ -87,53 +87,30 @@ export async function POST(req: Request) { ]; const genAI = new GoogleGenerativeAI(process.env.GOOGLE_API_KEY as string); - const model = genAI.getGenerativeModel({ - model: "gemini-pro", - safetySettings: mappedSafetySettings, - generationConfig: { - // candidateCount: 0, - // stopSequences: [], - maxOutputTokens: maxLength, - temperature, - topP, - topK, - }, - }); - - const countTokens = await model.countTokens(reqContent); - console.log("count tokens ------", countTokens); - try { - const streamingResp = await model.generateContentStream(reqContent); + const tokens = await genAI + .getGenerativeModel({ + model: "gemini-pro", + }) + .countTokens(reqContent); + console.log("count tokens ------", tokens); - const stream = new ReadableStream({ - async start(controller) { - try { - for await (const chunk of streamingResp.stream) { - if (chunk.candidates) { - const parts = chunk.candidates[0].content.parts; - const firstPart = parts[0]; - if (typeof firstPart.text === "string") { - // Encode the string text as bytes - const textEncoder = new TextEncoder(); - const encodedText = textEncoder.encode(firstPart.text); - controller.enqueue(encodedText); - } - } - } - controller.close(); - } catch (error) { - console.error("Streaming error:", error); - controller.error(error); - } + const geminiStream = await genAI + .getGenerativeModel({ + model: "gemini-pro", + safetySettings: mappedSafetySettings, + generationConfig: { + // candidateCount: 0, + // stopSequences: [], + maxOutputTokens: maxLength, + temperature, + topP, + topK, }, - }); + }) + .generateContentStream(reqContent); - return new Response(stream, { - headers: { "Content-Type": "text/plain" }, - }); - } catch (error) { - console.error("API error:", error); - return new Response("Internal Server Error", { status: 500 }); - } + const stream = GoogleGenerativeAIStream(geminiStream); + + return new StreamingTextResponse(stream); } diff --git a/src/app/api/gemini-vision/route.ts b/src/app/api/gemini-vision/route.ts index 5d89769..49a361a 100644 --- a/src/app/api/gemini-vision/route.ts +++ b/src/app/api/gemini-vision/route.ts @@ -1,4 +1,6 @@ // api/gemini-vision/route.ts +import { GoogleGenerativeAIStream, StreamingTextResponse } from "ai"; + import { GeneralSettings } from "@/types"; import { GoogleGenerativeAI, @@ -86,49 +88,22 @@ export async function POST(req: Request) { const genAI = new GoogleGenerativeAI(process.env.GOOGLE_API_KEY as string); - const model = genAI.getGenerativeModel({ - model: "gemini-pro-vision", - safetySettings: mappedSafetySettings, - generationConfig: { - // candidateCount: 0, - // stopSequences: [], - maxOutputTokens: maxLength, - temperature, - topP, - topK, - }, - }); - - try { - const streamingResp = await model.generateContentStream(reqContent); - - const stream = new ReadableStream({ - async start(controller) { - try { - for await (const chunk of streamingResp.stream) { - if (chunk.candidates) { - const parts = chunk.candidates[0].content.parts; - const firstPart = parts[0]; - if (typeof firstPart.text === "string") { - const textEncoder = new TextEncoder(); - const encodedText = textEncoder.encode(firstPart.text); - controller.enqueue(encodedText); - } - } - } - controller.close(); - } catch (error) { - console.error("Streaming error:", error); - controller.error(error); - } + const geminiStream = await genAI + .getGenerativeModel({ + model: "gemini-pro-vision", + safetySettings: mappedSafetySettings, + generationConfig: { + // candidateCount: 0, + // stopSequences: [], + maxOutputTokens: maxLength, + temperature, + topP, + topK, }, - }); + }) + .generateContentStream(reqContent); - return new Response(stream, { - headers: { "Content-Type": "text/plain" }, - }); - } catch (error) { - console.error("API error:", error); - return new Response("Internal Server Error", { status: 500 }); - } + const stream = GoogleGenerativeAIStream(geminiStream); + + return new StreamingTextResponse(stream); } diff --git a/yarn.lock b/yarn.lock index e62435a..1a9dc8b 100644 --- a/yarn.lock +++ b/yarn.lock @@ -732,10 +732,10 @@ agent-base@^7.0.2: dependencies: debug "^4.3.4" -ai@^2.2.29: - version "2.2.29" - resolved "https://registry.yarnpkg.com/ai/-/ai-2.2.29.tgz#a0522aff58be764c2d9c18c76806b9a8101194c2" - integrity sha512-/zzSTTKF5LxMGQuNVUnNjs7X6PWYfb6M88Zn74gCUnM3KCYgh0CiAWhLyhKP6UtK0H5mHSmXgt0ZkZYUecRp0w== +ai@^2.2.30: + version "2.2.30" + resolved "https://registry.yarnpkg.com/ai/-/ai-2.2.30.tgz#3b0af50859b44cfc7e168c170974a93637e90792" + integrity sha512-7dRgnEbYkbVjxyjiS7WEhNvO8ebeI4Om74D9OKXLK0yis4+s272pJ5I3vOAv3HaUBbVEiIFYQ7E34JH8XT1EeQ== dependencies: eventsource-parser "1.0.0" nanoid "3.3.6"