From 78a553fc86091afa0aa08e34d1acefe12664630e Mon Sep 17 00:00:00 2001 From: e-roy Date: Sun, 7 Jan 2024 13:22:55 -0500 Subject: [PATCH] server side validation and sanitized input --- package.json | 5 +- src/app/api/gemini-pro/route.ts | 26 ++++++++- src/app/api/gemini-vision/route.ts | 69 +++++++++++++++-------- src/components/markdown-viewer/code.tsx | 2 +- src/lib/sanitize-content.ts | 6 ++ src/lib/validate/common.ts | 48 ++++++++++++++++ src/lib/validate/pro-request-schema.ts | 14 +++++ src/lib/validate/vision-request-schema.ts | 16 ++++++ yarn.lock | 15 +++++ 9 files changed, 175 insertions(+), 26 deletions(-) create mode 100644 src/lib/sanitize-content.ts create mode 100644 src/lib/validate/common.ts create mode 100644 src/lib/validate/pro-request-schema.ts create mode 100644 src/lib/validate/vision-request-schema.ts diff --git a/package.json b/package.json index d0e50b4..33a2556 100644 --- a/package.json +++ b/package.json @@ -21,6 +21,7 @@ "ai": "^2.2.30", "class-variance-authority": "^0.7.0", "clsx": "^2.0.0", + "html-escaper": "^3.0.3", "lucide-react": "^0.297.0", "next": "14.0.4", "next-themes": "^0.2.1", @@ -33,9 +34,11 @@ "rehype-sanitize": "^6.0.0", "remark-gfm": "^4.0.0", "tailwind-merge": "^2.1.0", - "tailwindcss-animate": "^1.0.7" + "tailwindcss-animate": "^1.0.7", + "zod": "^3.22.4" }, "devDependencies": { + "@types/html-escaper": "^3.0.2", "@types/node": "^20", "@types/react": "^18", "@types/react-dom": "^18", diff --git a/src/app/api/gemini-pro/route.ts b/src/app/api/gemini-pro/route.ts index de83755..c987ec3 100644 --- a/src/app/api/gemini-pro/route.ts +++ b/src/app/api/gemini-pro/route.ts @@ -12,18 +12,40 @@ import { defaultSafetySettings, } from "@/lib/safety-settings-mapper"; +import { sanitizeContent } from "@/lib/sanitize-content"; + +import { proRequestSchema } from "@/lib/validate/pro-request-schema"; + import { GeneralSettings } from "@/types"; export const runtime = "edge"; export async function POST(req: Request) { - const { messages, general_settings, safety_settings } = await req.json(); + const parseResult = proRequestSchema.safeParse(await req.json()); + + if (!parseResult.success) { + // If validation fails, return a 400 Bad Request response + return new Response(JSON.stringify({ error: "Invalid request data" }), { + status: 400, + headers: { + "Content-Type": "application/json", + }, + }); + } + + const { messages, general_settings, safety_settings } = parseResult.data; const { temperature, maxLength, topP, topK } = general_settings as GeneralSettings; + for (const message of messages) { + message.content = sanitizeContent(message.content); + } + + const typedMessages: Message[] = messages as unknown as Message[]; + // consecutive user messages need to be merged into the same content, 2 consecutive Content objects with user role will error with the Gemini api const reqContent: GenerateContentRequest = { - contents: messages.reduce((acc: Content[], m: Message) => { + contents: typedMessages.reduce((acc: Content[], m: Message) => { if (m.role === "user") { const lastContent = acc[acc.length - 1]; if (lastContent && lastContent.role === "user") { diff --git a/src/app/api/gemini-vision/route.ts b/src/app/api/gemini-vision/route.ts index 6f7729b..5619fb7 100644 --- a/src/app/api/gemini-vision/route.ts +++ b/src/app/api/gemini-vision/route.ts @@ -4,6 +4,9 @@ import { GoogleGenerativeAIStream, StreamingTextResponse } from "ai"; import { GoogleGenerativeAI, GenerateContentRequest, + Part, + InlineDataPart, + TextPart, } from "@google/generative-ai"; import { @@ -13,49 +16,71 @@ import { import { GeneralSettings } from "@/types"; +import { visionRequestSchema } from "@/lib/validate/vision-request-schema"; + +import { sanitizeContent } from "@/lib/sanitize-content"; + export const runtime = "edge"; export async function POST(req: Request) { + const parseResult = visionRequestSchema.safeParse(await req.json()); + + if (!parseResult.success) { + // If validation fails, return a 400 Bad Request response + return new Response(JSON.stringify({ error: "Invalid request data" }), { + status: 400, + headers: { + "Content-Type": "application/json", + }, + }); + } + const { message, media, media_types, general_settings, safety_settings } = - await req.json(); + parseResult.data; + const { temperature, maxLength, topP, topK } = general_settings as GeneralSettings; - const userMessage = message; + const userMessage: string = sanitizeContent(message); + + const incomingSafetySettings = safety_settings || defaultSafetySettings; + const mappedSafetySettings = mapSafetySettings(incomingSafetySettings); + + const parts: Part[] = media.map( + (mediaData: string, index: number): InlineDataPart => ({ + inlineData: { + mimeType: media_types[index], + data: mediaData, + }, + }) + ); + + const userMessagePart: TextPart = { text: userMessage }; + parts.push(userMessagePart); const reqContent: GenerateContentRequest = { contents: [ { role: "user", - parts: media - .map((mediaData: string, index: number) => ({ - inline_data: { - mime_type: media_types[index], - data: mediaData, - }, - })) - .concat([{ text: `"""${userMessage}"""` }]), + parts: parts, }, ], + safetySettings: mappedSafetySettings, + generationConfig: { + // candidateCount: 0, + // stopSequences: [], + maxOutputTokens: maxLength, + temperature, + topP, + topK, + }, }; - const incomingSafetySettings = safety_settings || defaultSafetySettings; - const mappedSafetySettings = mapSafetySettings(incomingSafetySettings); - const genAI = new GoogleGenerativeAI(process.env.GOOGLE_API_KEY as string); const geminiStream = await genAI .getGenerativeModel({ model: "gemini-pro-vision", - safetySettings: mappedSafetySettings, - generationConfig: { - // candidateCount: 0, - // stopSequences: [], - maxOutputTokens: maxLength, - temperature, - topP, - topK, - }, }) .generateContentStream(reqContent); diff --git a/src/components/markdown-viewer/code.tsx b/src/components/markdown-viewer/code.tsx index aad43ce..b505d50 100644 --- a/src/components/markdown-viewer/code.tsx +++ b/src/components/markdown-viewer/code.tsx @@ -23,7 +23,7 @@ export const PreComponent: React.FC = ({ children }) => { console.error("Failed to copy!", err); } } - }, [children.props.children]); + }, [children?.props?.children]); return (
diff --git a/src/lib/sanitize-content.ts b/src/lib/sanitize-content.ts
new file mode 100644
index 0000000..a6004e9
--- /dev/null
+++ b/src/lib/sanitize-content.ts
@@ -0,0 +1,6 @@
+// lib/sanitize-content.ts
+import { escape } from "html-escaper";
+
+export const sanitizeContent = (content: string): string => {
+  return escape(content);
+};
diff --git a/src/lib/validate/common.ts b/src/lib/validate/common.ts
new file mode 100644
index 0000000..7611f88
--- /dev/null
+++ b/src/lib/validate/common.ts
@@ -0,0 +1,48 @@
+// lib/validate/common.ts
+import { z } from "zod";
+import { JSONValue } from "ai";
+
+export const safetySettingSchema = z.object({
+  harassment: z.number(),
+  hateSpeech: z.number(),
+  sexuallyExplicit: z.number(),
+  dangerousContent: z.number(),
+});
+
+const jsonValueSchema: z.ZodSchema = z.lazy(() =>
+  z.union([
+    z.null(),
+    z.string(),
+    z.number(),
+    z.boolean(),
+    z.array(jsonValueSchema),
+    z.record(jsonValueSchema),
+  ])
+);
+
+const functionCallSchema = z.object({
+  arguments: z.string().optional(),
+  name: z.string().optional(),
+});
+
+const toolCallSchema = z.object({
+  id: z.string(),
+  type: z.string(),
+  function: z.object({
+    name: z.string(),
+    arguments: z.string(),
+  }),
+});
+
+export const messageSchema = z.object({
+  id: z.string().optional(),
+  tool_call_id: z.string().optional(),
+  createdAt: z.date().optional(),
+  content: z.string(),
+  ui: z.any().optional(),
+  role: z.enum(["system", "user", "assistant", "function", "data", "tool"]),
+  name: z.string().optional(),
+  function_call: functionCallSchema.optional(),
+  data: jsonValueSchema.optional(),
+  tool_calls: z.array(toolCallSchema).optional(),
+});
diff --git a/src/lib/validate/pro-request-schema.ts b/src/lib/validate/pro-request-schema.ts
new file mode 100644
index 0000000..3f34c1e
--- /dev/null
+++ b/src/lib/validate/pro-request-schema.ts
@@ -0,0 +1,14 @@
+// lib/validate/pro-request-schema.ts
+import { z } from "zod";
+import { messageSchema, safetySettingSchema } from "./common";
+
+export const proRequestSchema = z.object({
+  messages: z.array(messageSchema),
+  general_settings: z.object({
+    temperature: z.number(),
+    maxLength: z.number(),
+    topP: z.number(),
+    topK: z.number(),
+  }),
+  safety_settings: safetySettingSchema.optional(),
+});
diff --git a/src/lib/validate/vision-request-schema.ts b/src/lib/validate/vision-request-schema.ts
new file mode 100644
index 0000000..f9ab0f2
--- /dev/null
+++ b/src/lib/validate/vision-request-schema.ts
@@ -0,0 +1,16 @@
+// lib/validate/vision-request-schema.ts
+import { z } from "zod";
+import { safetySettingSchema } from "./common";
+
+export const visionRequestSchema = z.object({
+  message: z.string(),
+  media: z.array(z.string()),
+  media_types: z.array(z.string()),
+  general_settings: z.object({
+    temperature: z.number(),
+    maxLength: z.number(),
+    topP: z.number(),
+    topK: z.number(),
+  }),
+  safety_settings: safetySettingSchema.optional(),
+});
diff --git a/yarn.lock b/yarn.lock
index 1a9dc8b..320642a 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -598,6 +598,11 @@
   dependencies:
     "@types/unist" "*"
 
+"@types/html-escaper@^3.0.2":
+  version "3.0.2"
+  resolved "https://registry.yarnpkg.com/@types/html-escaper/-/html-escaper-3.0.2.tgz#34d061611e993c67e3f054eae1912e97f6ea0169"
+  integrity sha512-A8vk09eyYzk8J/lFO4OUMKCmRN0rRzfZf4n3Olwapgox/PtTiU8zPYlL1UEkJ/WeHvV6v9Xnj3o/705PKz9r4Q==
+
 "@types/json5@^0.0.29":
   version "0.0.29"
   resolved "https://registry.yarnpkg.com/@types/json5/-/json5-0.0.29.tgz#ee28707ae94e11d2b827bcbe5270bcea7f3e71ee"
@@ -2058,6 +2063,11 @@ highlight.js@^10.4.1, highlight.js@~10.7.0:
   resolved "https://registry.yarnpkg.com/highlight.js/-/highlight.js-10.7.3.tgz#697272e3991356e40c3cac566a74eef681756531"
   integrity sha512-tzcUFauisWKNHaRkN4Wjl/ZA07gENAjFl3J/c480dprkGTg5EQstgaNFqBfUqCq54kZRIEcreTsAgF/m2quD7A==
 
+html-escaper@^3.0.3:
+  version "3.0.3"
+  resolved "https://registry.yarnpkg.com/html-escaper/-/html-escaper-3.0.3.tgz#4d336674652beb1dcbc29ef6b6ba7f6be6fdfed6"
+  integrity sha512-RuMffC89BOWQoY0WKGpIhn5gX3iI54O6nRA0yC124NYVtzjmFWBIiFd8M0x+ZdX0P9R4lADg1mgP8C7PxGOWuQ==
+
 html-url-attributes@^3.0.0:
   version "3.0.0"
   resolved "https://registry.yarnpkg.com/html-url-attributes/-/html-url-attributes-3.0.0.tgz#fc4abf0c3fb437e2329c678b80abb3c62cff6f08"
@@ -4306,6 +4316,11 @@ yocto-queue@^0.1.0:
   resolved "https://registry.yarnpkg.com/yocto-queue/-/yocto-queue-0.1.0.tgz#0294eb3dee05028d31ee1a5fa2c556a6aaf10a1b"
   integrity sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==
 
+zod@^3.22.4:
+  version "3.22.4"
+  resolved "https://registry.yarnpkg.com/zod/-/zod-3.22.4.tgz#f31c3a9386f61b1f228af56faa9255e845cf3fff"
+  integrity sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==
+
 zwitch@^2.0.0:
   version "2.0.4"
   resolved "https://registry.yarnpkg.com/zwitch/-/zwitch-2.0.4.tgz#c827d4b0acb76fc3e685a4c6ec2902d51070e9d7"