diff --git a/apis/cloudflare/.gitignore b/apis/cloudflare/.gitignore index 59a3378..483b1a1 100644 --- a/apis/cloudflare/.gitignore +++ b/apis/cloudflare/.gitignore @@ -172,3 +172,5 @@ dist .dev.vars wrangler.toml + +certs/ diff --git a/apis/cloudflare/e2e/auth.test.ts b/apis/cloudflare/e2e/auth.test.ts new file mode 100644 index 0000000..403abce --- /dev/null +++ b/apis/cloudflare/e2e/auth.test.ts @@ -0,0 +1,72 @@ +import { createRandomToken } from "./utils/testAuth"; +import { expect, test } from "bun:test"; +import { PROXY, TOKEN } from "./utils/constants"; + +const COMMON_HEADERS = { + "Content-Type": "application/json", + "X-Grit-Api": TOKEN, +}; + +const OPENAI_HEADERS_MISSING_GRIT_KEY = { + "Content-Type": "application/json", +}; + +test("auth__failsWhenMissingGritKey", async () => { + const res = await fetch(PROXY, { + headers: OPENAI_HEADERS_MISSING_GRIT_KEY, + method: "POST", + body: JSON.stringify({ + model: "gpt-3.5-turbo", + messages: [ + { + role: "user", + content: "What is a proxy?", + }, + ], + seed: 8, + }), + }); + + expect(res.status).toBe(400); +}); + +test("auth__failsWhenInvalidGritKey", async () => { + const res = await fetch(PROXY, { + headers: { ...OPENAI_HEADERS_MISSING_GRIT_KEY, "X-Grit-Api": "junk" }, + method: "POST", + body: JSON.stringify({ + model: "gpt-3.5-turbo", + messages: [ + { + role: "user", + content: "What is a proxy?", + }, + ], + seed: 1, + }), + }); + + expect(res.status).toBe(401); +}); + +test("auth__succeedsWhenProvidedGritKey", async () => { + const res = await fetch(PROXY, { + headers: COMMON_HEADERS, + method: "POST", + body: JSON.stringify({ + model: "gpt-3.5-turbo", + messages: [ + { + role: "user", + content: "What is a proxy?", + }, + ], + seed: 1, + }), + }); + + expect(res.status).toBe(200); + + const antCompletions = await res.json(); + expect(antCompletions.choices.length >= 1).toBe(true); +}); diff --git a/apis/cloudflare/e2e/basic.test.ts b/apis/cloudflare/e2e/basic.test.ts new file mode 100644 index 0000000..5bc0dbd --- /dev/null +++ b/apis/cloudflare/e2e/basic.test.ts @@ -0,0 +1,50 @@ +import { PROXY, TOKEN } from "./utils/constants"; +import { createRandomToken } from "./utils/testAuth"; +import { expect, test } from "bun:test"; + +const COMMON_HEADERS = { + "Content-Type": "application/json", + "X-Grit-Api": TOKEN, +}; + +test("basic__routesBetweenModels", async () => { + const res = await fetch(PROXY, { + headers: COMMON_HEADERS, + method: "POST", + body: JSON.stringify({ + model: "gpt-3.5-turbo", + messages: [ + { + role: "user", + content: "What is a proxy?", + }, + ], + seed: 1, + }), + }); + + expect(res.status).toBe(200); + + const openAiCompletions = await res.json(); + expect(openAiCompletions.choices.length >= 1).toBe(true); + + const antRes = await fetch(PROXY, { + headers: COMMON_HEADERS, + method: "POST", + body: JSON.stringify({ + model: "claude-2.1", + messages: [ + { + role: "user", + content: "What is a proxy?", + }, + ], + seed: 1, + }), + }); + + expect(antRes.status).toBe(200); + + const antCompletions = await antRes.json(); + expect(antCompletions.choices.length >= 1).toBe(true); +}); diff --git a/apis/cloudflare/e2e/utils/constants.ts b/apis/cloudflare/e2e/utils/constants.ts new file mode 100644 index 0000000..f8df1bf --- /dev/null +++ b/apis/cloudflare/e2e/utils/constants.ts @@ -0,0 +1,6 @@ +const ROOT = process.env.PROXY ?? "localhost:62952"; + +export const PROXY = `${ROOT}/v1/chat/completions`; + +export const TOKEN = + "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMiwiaXNzIjoiZ3JpdC1sbG0tcm91dGVyIn0.Hx2Z14WQdNB_bG8BwApb3WVKfReQo8krOPxtsTHhjgTjV45pMo197uRy0SdgFHfVg0_t0h705XNgF5zBv-DXqsiqBagNr9yip8vSKGltSAd04PEiFsRUlLxaWuMqHfBEhfeSA0x46u14UytOBV38auz4FPC3IzV_E7_a5fs6X2MGUCm1bALp9vZh5KBHNBxUpb7yNT5T5f5pzxNP7qyMdttEndSJgkLZbjpFOyZqAkgs59MyqXxlasxnk5aWAjEn7S4HPkLhB8btTAwwjRXQpnDZQ7xjtrA9bB2ZnQIq8g6xpI6J_vmHZkLskednoRHOSRt4_BkpOsGxtshn72UHoVqWNZTbspvqZANQlB5CnqiAqO2mlmUk4Mrk4Gg89BioMhYCDOUps-qlyDuXeRSMZn2H81BYh6oT-dai-qE5zKWy2Z1lOAdwWzgndn0qBtPvUvyIqe7wqGtNLEvRccRIMGuSv68GcGWuxhE9rRK1w1eOcM5I7I-sYlELbhX_qJ147azxiFjULzJDS8jDge0npMDUJVwpX_BAaG3A2qLQhZMFduQrCbfM4NqlKQ0qTb6-M1I6dOv-zg7A0Q3jcCLFG0bFZczMACmutYME-m44EK7ruIvc6pWom9bN2wtjsqoXS0FCtkji55MGA_hBGVmB8qgOAHg_JxYssuz2tBnv6e4"; diff --git a/apis/cloudflare/package.json b/apis/cloudflare/package.json index a2afbb9..a54b33d 100644 --- a/apis/cloudflare/package.json +++ b/apis/cloudflare/package.json @@ -8,19 +8,23 @@ "dev": "wrangler dev --port 8787 --inspector-port 9299", "start": "wrangler dev", "watch": "tsup --watch --dts", - "build": "tsup --clean --dts" + "build": "tsup --clean --dts", + "test": "bun --env-file=.env test", + "test:prod": "export PROXY=https://proxy.admin-a65.workers.dev && bun --env-file=.env test" }, "devDependencies": { "@cloudflare/workers-types": "^4.20240512.0", + "@types/jsonwebtoken": "^9.0.6", "itty-router": "^3.0.12", + "tsup": "^8.0.1", "typescript": "^5.0.4", - "wrangler": "^3.57.1", - "tsup": "^8.0.1" + "wrangler": "^3.57.1" }, "dependencies": { "@braintrust/proxy": "workspace:*", "@opentelemetry/resources": "^1.18.1", "@opentelemetry/sdk-metrics": "^1.18.1", - "dotenv": "^16.3.1" + "dotenv": "^16.3.1", + "jose": "^5.6.3" } } diff --git a/apis/cloudflare/src/auth.ts b/apis/cloudflare/src/auth.ts new file mode 100644 index 0000000..4f61a00 --- /dev/null +++ b/apis/cloudflare/src/auth.ts @@ -0,0 +1,30 @@ +import { verify } from "jsonwebtoken"; +import * as jose from "jose"; + +export async function authenticateToken( + token: string, + env: Env, +): Promise { + if (!env.JWT_PUB_KEY) { + throw Error("Expected JWT_PUB_KEY in env"); + } + + const pubKey = await jose.importSPKI(env.JWT_PUB_KEY, "RS256"); + + try { + await jose.jwtVerify(token, pubKey, { + issuer: "grit-llm-router", + }); + } catch (error) { + console.error(`Error verifying token ${error}`); + return false; + } + + return true; +} + +export function parseGritToken(headers: Headers): string | null { + const authHeader = headers.get("X-Grit-Api"); + + return authHeader; +} diff --git a/apis/cloudflare/src/keyPicker.ts b/apis/cloudflare/src/keyPicker.ts new file mode 100644 index 0000000..40126b6 --- /dev/null +++ b/apis/cloudflare/src/keyPicker.ts @@ -0,0 +1,19 @@ +type KeyNames = "OPENAI_API_KEY" | "ANT_API_KEY"; +const modelToKeyName: Map = new Map([ + ["gpt-3.5-turbo", "OPENAI_API_KEY"], + ["claude-2.1", "ANT_API_KEY"], +]); + +export function getProviderKey(modelName: string, env: Env): string | null { + if (!modelToKeyName.has(modelName)) { + return null; + } + + const keyName = modelToKeyName.get(modelName); + + if (!keyName) { + return null; + } + + return env[keyName] ?? null; +} diff --git a/apis/cloudflare/src/proxy.ts b/apis/cloudflare/src/proxy.ts index f8bac2c..c492814 100644 --- a/apis/cloudflare/src/proxy.ts +++ b/apis/cloudflare/src/proxy.ts @@ -1,6 +1,8 @@ import { EdgeProxyV1, FlushingExporter } from "@braintrust/proxy/edge"; import { NOOP_METER_PROVIDER, initMetrics } from "@braintrust/proxy"; import { PrometheusMetricAggregator } from "./metric-aggregator"; +import { authenticateToken, parseGritToken } from "./auth"; +import { getProviderKey } from "./keyPicker"; export const proxyV1Prefix = "/v1"; @@ -12,6 +14,10 @@ declare global { PROMETHEUS_SCRAPE_USER?: string; PROMETHEUS_SCRAPE_PASSWORD?: string; WHITELISTED_ORIGINS?: string; + JWT_PUB_KEY?: string; + JWT_SECRET?: string; + OPENAI_API_KEY?: string; + ANT_API_KEY?: string; } } @@ -64,6 +70,64 @@ export async function handleProxyV1( "cloudflare-metrics", ); + const gritToken = parseGritToken(request.headers); + + if (!gritToken) { + return new Response("Missing X-Grit-Api Header", { + status: 400, + headers: { + "Content-Type": "text/plain", + }, + }); + } + + const isAuthed = await authenticateToken(gritToken, env); + + if (!isAuthed) { + return new Response("Invalid X-Grit-Api", { + status: 401, + headers: { + "Content-Type": "text/plain", + }, + }); + } + + const clonedRequest = request.clone(); + + const body = await clonedRequest.json(); + + if ( + typeof body !== "object" || + !body || + !("model" in body) || + typeof body.model !== "string" + ) { + return new Response("Expected model in body", { + status: 400, + headers: { + "Content-Type": "text/plain", + }, + }); + } + + const providerKey = getProviderKey(body.model, env); + + if (!providerKey) { + return new Response(`Model ${body.model} not found`, { + status: 404, + headers: { + "Content-Type": "text/plain", + }, + }); + } + + const headers = new Headers(request.headers); + headers.set("Authorization", `Bearer ${providerKey}`); + + const reqWithKey = new Request(request, { + headers, + }); + const whitelist = originWhitelist(env); const cacheGetLatency = meter.createHistogram("results_cache_get_latency"); @@ -120,7 +184,7 @@ export async function handleProxyV1( braintrustApiUrl: braintrustAppUrl(env).toString(), meterProvider, whitelist, - })(request, ctx); + })(reqWithKey, ctx); } export async function handlePrometheusScrape( diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 566d3eb..0c77d87 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -32,10 +32,16 @@ importers: dotenv: specifier: ^16.3.1 version: 16.3.1 + jose: + specifier: ^5.6.3 + version: 5.6.3 devDependencies: '@cloudflare/workers-types': specifier: ^4.20240512.0 version: 4.20240512.0 + '@types/jsonwebtoken': + specifier: ^9.0.6 + version: 9.0.6 itty-router: specifier: ^3.0.12 version: 3.0.12 @@ -2230,6 +2236,12 @@ packages: resolution: {integrity: sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==} dev: true + /@types/jsonwebtoken@9.0.6: + resolution: {integrity: sha512-/5hndP5dCjloafCXns6SZyESp3Ldq7YjH3zwzwczYnjxIT0Fqzk5ROSYVGfFyczIue7IUEj8hkvLbPoLQ18vQw==} + dependencies: + '@types/node': 20.10.5 + dev: true + /@types/mime@1.3.5: resolution: {integrity: sha512-/pyBZWSLD2n0dcHE3hq8s8ZvcETHtEuF+3E7XVt0Ig2nvsVQXdghHVcEkIWjy9A0wKfTn97a/PSDYohKIlnP/w==} dev: true @@ -4641,6 +4653,10 @@ packages: engines: {node: '>= 0.6.0'} dev: false + /jose@5.6.3: + resolution: {integrity: sha512-1Jh//hEEwMhNYPDDLwXHa2ePWgWiFNNUadVmguAAw2IJ6sj9mNxV5tGXJNqlMkJAybF6Lgw1mISDxTePP/187g==} + dev: false + /joycon@3.1.1: resolution: {integrity: sha512-34wB/Y7MW7bzjKRjUKTa46I2Z7eV62Rkhva+KkopW7Qvv/OSWBqvkSY7vusOPrNuZcUG3tApvdVgNB8POj3SPw==} engines: {node: '>=10'}