Skip to content

Commit

Permalink
Merge pull request #17 from e-roy:safety-settings-mapper
Browse files Browse the repository at this point in the history
safety settings mapper helper function
  • Loading branch information
e-roy authored Jan 7, 2024
2 parents a34e665 + 8e6c814 commit 9006d1a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 102 deletions.
59 changes: 7 additions & 52 deletions src/app/api/gemini-pro/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,25 @@ import { GoogleGenerativeAIStream, Message, StreamingTextResponse } from "ai";

import {
GoogleGenerativeAI,
HarmCategory,
HarmBlockThreshold,
GenerateContentRequest,
Content,
} from "@google/generative-ai";

import { GeneralSettings } from "@/types";

const mapSafetyValueToThreshold = (value: number): HarmBlockThreshold => {
switch (value) {
case 0:
return HarmBlockThreshold.BLOCK_NONE;
case 1:
return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE;
case 2:
return HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE;
case 3:
return HarmBlockThreshold.BLOCK_ONLY_HIGH;
default:
return HarmBlockThreshold.BLOCK_NONE;
}
};
import {
mapSafetySettings,
defaultSafetySettings,
} from "@/lib/safety-settings-mapper";

const defaultSafetySettings = {
harassment: 0,
hateSpeech: 0,
sexuallyExplicit: 0,
dangerousContent: 0,
};
import { GeneralSettings } from "@/types";

export const runtime = "edge";

export async function POST(req: Request) {
const { messages, general_settings, safety_settings } = await req.json();
const { temperature, maxLength, topP, topK } =
general_settings as GeneralSettings;
// console.log(temperature, maxLength, topP, topK);
// console.log("general_settings", general_settings);
// console.log("safety_settings", safety_settings);
// console.log("messages =================>", messages);

// consecutive user messages need to be merged into the same content, 2 consecutive Content objects will error with the Gemini api
// consecutive user messages need to be merged into the same content, 2 consecutive Content objects with user role will error with the Gemini api
const reqContent: GenerateContentRequest = {
contents: messages.reduce((acc: Content[], m: Message) => {
if (m.role === "user") {
Expand All @@ -69,29 +46,7 @@ export async function POST(req: Request) {
};

const incomingSafetySettings = safety_settings || defaultSafetySettings;

const mappedSafetySettings = [
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: mapSafetyValueToThreshold(incomingSafetySettings.harassment),
},
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: mapSafetyValueToThreshold(incomingSafetySettings.hateSpeech),
},
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: mapSafetyValueToThreshold(
incomingSafetySettings.sexuallyExplicit
),
},
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: mapSafetyValueToThreshold(
incomingSafetySettings.dangerousContent
),
},
];
const mappedSafetySettings = mapSafetySettings(incomingSafetySettings);

const genAI = new GoogleGenerativeAI(process.env.GOOGLE_API_KEY as string);

Expand Down
56 changes: 6 additions & 50 deletions src/app/api/gemini-vision/route.ts
Original file line number Diff line number Diff line change
@@ -1,35 +1,17 @@
// api/gemini-vision/route.ts
import { GoogleGenerativeAIStream, StreamingTextResponse } from "ai";

import { GeneralSettings } from "@/types";
import {
GoogleGenerativeAI,
HarmCategory,
HarmBlockThreshold,
GenerateContentRequest,
} from "@google/generative-ai";

const mapSafetyValueToThreshold = (value: number): HarmBlockThreshold => {
switch (value) {
case 0:
return HarmBlockThreshold.BLOCK_NONE;
case 1:
return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE;
case 2:
return HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE;
case 3:
return HarmBlockThreshold.BLOCK_ONLY_HIGH;
default:
return HarmBlockThreshold.BLOCK_NONE;
}
};
import {
mapSafetySettings,
defaultSafetySettings,
} from "@/lib/safety-settings-mapper";

const defaultSafetySettings = {
harassment: 0,
hateSpeech: 0,
sexuallyExplicit: 0,
dangerousContent: 0,
};
import { GeneralSettings } from "@/types";

export const runtime = "edge";

Expand All @@ -38,10 +20,6 @@ export async function POST(req: Request) {
await req.json();
const { temperature, maxLength, topP, topK } =
general_settings as GeneralSettings;
// console.log(temperature, maxLength, topP, topK);
// console.log(media, media_types);
// console.log(safety_settings);
// console.log("message =================>", message);

const userMessage = message;

Expand All @@ -62,29 +40,7 @@ export async function POST(req: Request) {
};

const incomingSafetySettings = safety_settings || defaultSafetySettings;

const mappedSafetySettings = [
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: mapSafetyValueToThreshold(incomingSafetySettings.harassment),
},
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: mapSafetyValueToThreshold(incomingSafetySettings.hateSpeech),
},
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: mapSafetyValueToThreshold(
incomingSafetySettings.sexuallyExplicit
),
},
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: mapSafetyValueToThreshold(
incomingSafetySettings.dangerousContent
),
},
];
const mappedSafetySettings = mapSafetySettings(incomingSafetySettings);

const genAI = new GoogleGenerativeAI(process.env.GOOGLE_API_KEY as string);

Expand Down
47 changes: 47 additions & 0 deletions src/lib/safety-settings-mapper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// lib/safety-settings-mapper.ts
import { HarmBlockThreshold, HarmCategory } from "@google/generative-ai";

export const mapSafetyValueToThreshold = (
value: number
): HarmBlockThreshold => {
switch (value) {
case 0:
return HarmBlockThreshold.BLOCK_NONE;
case 1:
return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE;
case 2:
return HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE;
case 3:
return HarmBlockThreshold.BLOCK_ONLY_HIGH;
default:
return HarmBlockThreshold.BLOCK_NONE;
}
};

export const defaultSafetySettings = {
harassment: 0,
hateSpeech: 0,
sexuallyExplicit: 0,
dangerousContent: 0,
};

export const mapSafetySettings = (
safetySettings: typeof defaultSafetySettings
) => [
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: mapSafetyValueToThreshold(safetySettings.harassment),
},
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: mapSafetyValueToThreshold(safetySettings.hateSpeech),
},
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: mapSafetyValueToThreshold(safetySettings.sexuallyExplicit),
},
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: mapSafetyValueToThreshold(safetySettings.dangerousContent),
},
];

0 comments on commit 9006d1a

Please sign in to comment.