Skip to content

Commit

Permalink
Merge pull request #18 from e-roy:server-side-validation-and-sanitize…
Browse files Browse the repository at this point in the history
…d-input

server side validation and sanitized input
  • Loading branch information
e-roy authored Jan 7, 2024
2 parents 9006d1a + 78a553f commit fca4f8b
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 26 deletions.
5 changes: 4 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"ai": "^2.2.30",
"class-variance-authority": "^0.7.0",
"clsx": "^2.0.0",
"html-escaper": "^3.0.3",
"lucide-react": "^0.297.0",
"next": "14.0.4",
"next-themes": "^0.2.1",
Expand All @@ -33,9 +34,11 @@
"rehype-sanitize": "^6.0.0",
"remark-gfm": "^4.0.0",
"tailwind-merge": "^2.1.0",
"tailwindcss-animate": "^1.0.7"
"tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4"
},
"devDependencies": {
"@types/html-escaper": "^3.0.2",
"@types/node": "^20",
"@types/react": "^18",
"@types/react-dom": "^18",
Expand Down
26 changes: 24 additions & 2 deletions src/app/api/gemini-pro/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,40 @@ import {
defaultSafetySettings,
} from "@/lib/safety-settings-mapper";

import { sanitizeContent } from "@/lib/sanitize-content";

import { proRequestSchema } from "@/lib/validate/pro-request-schema";

import { GeneralSettings } from "@/types";

export const runtime = "edge";

export async function POST(req: Request) {
const { messages, general_settings, safety_settings } = await req.json();
const parseResult = proRequestSchema.safeParse(await req.json());

if (!parseResult.success) {
// If validation fails, return a 400 Bad Request response
return new Response(JSON.stringify({ error: "Invalid request data" }), {
status: 400,
headers: {
"Content-Type": "application/json",
},
});
}

const { messages, general_settings, safety_settings } = parseResult.data;
const { temperature, maxLength, topP, topK } =
general_settings as GeneralSettings;

for (const message of messages) {
message.content = sanitizeContent(message.content);
}

const typedMessages: Message[] = messages as unknown as Message[];

// consecutive user messages need to be merged into the same content, 2 consecutive Content objects with user role will error with the Gemini api
const reqContent: GenerateContentRequest = {
contents: messages.reduce((acc: Content[], m: Message) => {
contents: typedMessages.reduce((acc: Content[], m: Message) => {
if (m.role === "user") {
const lastContent = acc[acc.length - 1];
if (lastContent && lastContent.role === "user") {
Expand Down
69 changes: 47 additions & 22 deletions src/app/api/gemini-vision/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import { GoogleGenerativeAIStream, StreamingTextResponse } from "ai";
import {
GoogleGenerativeAI,
GenerateContentRequest,
Part,
InlineDataPart,
TextPart,
} from "@google/generative-ai";

import {
Expand All @@ -13,49 +16,71 @@ import {

import { GeneralSettings } from "@/types";

import { visionRequestSchema } from "@/lib/validate/vision-request-schema";

import { sanitizeContent } from "@/lib/sanitize-content";

export const runtime = "edge";

export async function POST(req: Request) {
const parseResult = visionRequestSchema.safeParse(await req.json());

if (!parseResult.success) {
// If validation fails, return a 400 Bad Request response
return new Response(JSON.stringify({ error: "Invalid request data" }), {
status: 400,
headers: {
"Content-Type": "application/json",
},
});
}

const { message, media, media_types, general_settings, safety_settings } =
await req.json();
parseResult.data;

const { temperature, maxLength, topP, topK } =
general_settings as GeneralSettings;

const userMessage = message;
const userMessage: string = sanitizeContent(message);

const incomingSafetySettings = safety_settings || defaultSafetySettings;
const mappedSafetySettings = mapSafetySettings(incomingSafetySettings);

const parts: Part[] = media.map(
(mediaData: string, index: number): InlineDataPart => ({
inlineData: {
mimeType: media_types[index],
data: mediaData,
},
})
);

const userMessagePart: TextPart = { text: userMessage };
parts.push(userMessagePart);

const reqContent: GenerateContentRequest = {
contents: [
{
role: "user",
parts: media
.map((mediaData: string, index: number) => ({
inline_data: {
mime_type: media_types[index],
data: mediaData,
},
}))
.concat([{ text: `"""${userMessage}"""` }]),
parts: parts,
},
],
safetySettings: mappedSafetySettings,
generationConfig: {
// candidateCount: 0,
// stopSequences: [],
maxOutputTokens: maxLength,
temperature,
topP,
topK,
},
};

const incomingSafetySettings = safety_settings || defaultSafetySettings;
const mappedSafetySettings = mapSafetySettings(incomingSafetySettings);

const genAI = new GoogleGenerativeAI(process.env.GOOGLE_API_KEY as string);

const geminiStream = await genAI
.getGenerativeModel({
model: "gemini-pro-vision",
safetySettings: mappedSafetySettings,
generationConfig: {
// candidateCount: 0,
// stopSequences: [],
maxOutputTokens: maxLength,
temperature,
topP,
topK,
},
})
.generateContentStream(reqContent);

Expand Down
2 changes: 1 addition & 1 deletion src/components/markdown-viewer/code.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export const PreComponent: React.FC<IPreProps> = ({ children }) => {
console.error("Failed to copy!", err);
}
}
}, [children.props.children]);
}, [children?.props?.children]);

return (
<pre className="bg-[#2B2B2B] rounded-md p-2 text-neutral-50 mb-4">
Expand Down
6 changes: 6 additions & 0 deletions src/lib/sanitize-content.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// lib/sanitize-content.ts
import { escape } from "html-escaper";

export const sanitizeContent = (content: string): string => {
return escape(content);
};
48 changes: 48 additions & 0 deletions src/lib/validate/common.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// lib/validate/common.ts
import { z } from "zod";
import { JSONValue } from "ai";

export const safetySettingSchema = z.object({
harassment: z.number(),
hateSpeech: z.number(),
sexuallyExplicit: z.number(),
dangerousContent: z.number(),
});

const jsonValueSchema: z.ZodSchema<JSONValue> = z.lazy(() =>
z.union([
z.null(),
z.string(),
z.number(),
z.boolean(),
z.array(jsonValueSchema),
z.record(jsonValueSchema),
])
);

const functionCallSchema = z.object({
arguments: z.string().optional(),
name: z.string().optional(),
});

const toolCallSchema = z.object({
id: z.string(),
type: z.string(),
function: z.object({
name: z.string(),
arguments: z.string(),
}),
});

export const messageSchema = z.object({
id: z.string().optional(),
tool_call_id: z.string().optional(),
createdAt: z.date().optional(),
content: z.string(),
ui: z.any().optional(),
role: z.enum(["system", "user", "assistant", "function", "data", "tool"]),
name: z.string().optional(),
function_call: functionCallSchema.optional(),
data: jsonValueSchema.optional(),
tool_calls: z.array(toolCallSchema).optional(),
});
14 changes: 14 additions & 0 deletions src/lib/validate/pro-request-schema.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// lib/validate/pro-request-schema.ts
import { z } from "zod";
import { messageSchema, safetySettingSchema } from "./common";

export const proRequestSchema = z.object({
messages: z.array(messageSchema),
general_settings: z.object({
temperature: z.number(),
maxLength: z.number(),
topP: z.number(),
topK: z.number(),
}),
safety_settings: safetySettingSchema.optional(),
});
16 changes: 16 additions & 0 deletions src/lib/validate/vision-request-schema.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// lib/validate/vision-request-schema.ts
import { z } from "zod";
import { safetySettingSchema } from "./common";

export const visionRequestSchema = z.object({
message: z.string(),
media: z.array(z.string()),
media_types: z.array(z.string()),
general_settings: z.object({
temperature: z.number(),
maxLength: z.number(),
topP: z.number(),
topK: z.number(),
}),
safety_settings: safetySettingSchema.optional(),
});
15 changes: 15 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,11 @@
dependencies:
"@types/unist" "*"

"@types/html-escaper@^3.0.2":
version "3.0.2"
resolved "https://registry.yarnpkg.com/@types/html-escaper/-/html-escaper-3.0.2.tgz#34d061611e993c67e3f054eae1912e97f6ea0169"
integrity sha512-A8vk09eyYzk8J/lFO4OUMKCmRN0rRzfZf4n3Olwapgox/PtTiU8zPYlL1UEkJ/WeHvV6v9Xnj3o/705PKz9r4Q==

"@types/json5@^0.0.29":
version "0.0.29"
resolved "https://registry.yarnpkg.com/@types/json5/-/json5-0.0.29.tgz#ee28707ae94e11d2b827bcbe5270bcea7f3e71ee"
Expand Down Expand Up @@ -2058,6 +2063,11 @@ highlight.js@^10.4.1, highlight.js@~10.7.0:
resolved "https://registry.yarnpkg.com/highlight.js/-/highlight.js-10.7.3.tgz#697272e3991356e40c3cac566a74eef681756531"
integrity sha512-tzcUFauisWKNHaRkN4Wjl/ZA07gENAjFl3J/c480dprkGTg5EQstgaNFqBfUqCq54kZRIEcreTsAgF/m2quD7A==

html-escaper@^3.0.3:
version "3.0.3"
resolved "https://registry.yarnpkg.com/html-escaper/-/html-escaper-3.0.3.tgz#4d336674652beb1dcbc29ef6b6ba7f6be6fdfed6"
integrity sha512-RuMffC89BOWQoY0WKGpIhn5gX3iI54O6nRA0yC124NYVtzjmFWBIiFd8M0x+ZdX0P9R4lADg1mgP8C7PxGOWuQ==

html-url-attributes@^3.0.0:
version "3.0.0"
resolved "https://registry.yarnpkg.com/html-url-attributes/-/html-url-attributes-3.0.0.tgz#fc4abf0c3fb437e2329c678b80abb3c62cff6f08"
Expand Down Expand Up @@ -4306,6 +4316,11 @@ yocto-queue@^0.1.0:
resolved "https://registry.yarnpkg.com/yocto-queue/-/yocto-queue-0.1.0.tgz#0294eb3dee05028d31ee1a5fa2c556a6aaf10a1b"
integrity sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==

zod@^3.22.4:
version "3.22.4"
resolved "https://registry.yarnpkg.com/zod/-/zod-3.22.4.tgz#f31c3a9386f61b1f228af56faa9255e845cf3fff"
integrity sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==

zwitch@^2.0.0:
version "2.0.4"
resolved "https://registry.yarnpkg.com/zwitch/-/zwitch-2.0.4.tgz#c827d4b0acb76fc3e685a4c6ec2902d51070e9d7"
Expand Down

0 comments on commit fca4f8b

Please sign in to comment.