Skip to content

Commit

Permalink
feat: added azp check
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyas-londhe committed Sep 20, 2024
1 parent 97ee4a8 commit 9aee0fe
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 14 deletions.
25 changes: 17 additions & 8 deletions packages/circuits/jwt-verifier.circom
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ include "./utils/constants.circom";
*
* @output sha[256] The SHA256 hash of the JWT message, computed for signature verification.
*/
template JWTVerifier(n, k, maxMessageLength, maxB64HeaderLength, maxB64PayloadLength) {
template JWTVerifier(n, k, maxMessageLength, maxB64HeaderLength, maxB64PayloadLength, azpLength) {
signal input message[maxMessageLength]; // JWT message (header + payload)
signal input messageLength; // Length of the message signed in the JWT
signal input pubkey[k]; // RSA public key split into k chunks
Expand All @@ -37,7 +37,9 @@ template JWTVerifier(n, k, maxMessageLength, maxB64HeaderLength, maxB64PayloadLe

signal input jwtTypStartIndex; // Index of the "typ" in the JWT header
signal input jwtAlgStartIndex; // Index of the "alg" in the JWT header
signal input commandStartIndex; // Index of the key `command` in the JWT payload

signal input azpKeyStartIndex; // Index of the "azp" (Authorized party) key in the JWT payload
signal input azp[azpLength]; // "azp" (Authorized party) in the JWT payload

assert(maxMessageLength % 64 == 0);
assert(n * k > 2048); // to support 2048 bit RSA
Expand Down Expand Up @@ -122,11 +124,18 @@ template JWTVerifier(n, k, maxMessageLength, maxB64HeaderLength, maxB64PayloadLe
algMatch[i] === alg[i];
}

// Verify if `command` key exists in the payload
var commandLength = COMMAND_LENGTH();
var command[commandLength] = COMMAND();
signal commandMatch[commandLength] <== RevealSubstring(maxPayloadLength, commandLength, 1)(payload, commandStartIndex, commandLength);
for (var i = 0; i < commandLength; i++) {
commandMatch[i] === command[i];
// Verify if the key `azp` in the payload is unique
var azpKeyLength = AZP_KEY_LENGTH();
var azpKey[azpKeyLength] = AZP_KEY();
signal azpKeyMatch[azpKeyLength] <== RevealSubstring(maxPayloadLength, azpKeyLength, 1)(payload, azpKeyStartIndex, azpKeyLength);
for (var i = 0; i < azpKeyLength; i++) {
azpKeyMatch[i] === azpKey[i];
}

// Verify if azp is correct
signal azpStartIndex <== azpKeyStartIndex + azpKeyLength + 1;
signal azpMatch[azpLength] <== RevealSubstring(maxPayloadLength, azpLength, 0)(payload, azpStartIndex, azpLength);
for (var i = 0; i < azpLength; i++) {
azpMatch[i] === azp[i];
}
}
3 changes: 2 additions & 1 deletion packages/circuits/tests/jwt-verifier.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ describe("JWT Verifier Circuit", () => {
typ: "JWT",
};
const payload = {
command: "register",
name: "John Doe",
iat: Math.floor(Date.now() / 1000),
azp: "demo-client-id",
};

const { rawJWT: jwt, publicKey: key } = generateJWT(header, payload);
Expand All @@ -50,6 +50,7 @@ describe("JWT Verifier Circuit", () => {
},
{
maxMessageLength: 256,
azp: "demo-client-id",
}
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ pragma circom 2.1.6;

include "../../jwt-verifier.circom";

component main { public [ pubkey ] } = JWTVerifier(121, 17, 256, 64, 96);
component main { public [ pubkey ] } = JWTVerifier(121, 17, 256, 64, 96, 14);
10 changes: 10 additions & 0 deletions packages/circuits/utils/constants.circom
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ function JWT_ALG_LENGTH() {
return 13;
}

function AZP_KEY_LENGTH() {
// len("azp":)
return 6;
}

function COMMAND_LENGTH() {
// len("command":)
return 10;
Expand All @@ -25,6 +30,11 @@ function JWT_ALG() {
return [34, 97, 108, 103, 34, 58, 34, 82, 83, 50, 53, 54, 34];
}

function AZP_KEY() {
// "azp":
return [34, 97, 122, 112, 34, 58];
}

function COMMAND() {
// "command":
return [34, 99, 111, 109, 109, 97, 110, 100, 34, 58];
Expand Down
9 changes: 6 additions & 3 deletions packages/helpers/src/input-generators.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export interface RSAPublicKey {

type JWTInputGenerationArgs = {
maxMessageLength?: number; // Max length of the JWT message including padding
azp?: string; // "azp" (Authorized party) in the JWT payload
};

/**
Expand All @@ -36,7 +37,8 @@ export async function generateJWTVerifierInputs(
periodIndex: string; // Index of the period in the JWT message
jwtTypStartIndex: string; // Index of the "typ" in the JWT header
jwtAlgStartIndex: string; // Index of the "alg" in the JWT header
commandStartIndex: string; // Index of the "command" in the JWT payload
azpKeyStartIndex: string; // Index of the "azp" in the JWT payload
azp: string[]; // "azp" in the JWT payload
}> {
// Find the index of the period in the JWT message
const periodIndex = rawJWT.indexOf(".");
Expand Down Expand Up @@ -64,7 +66,7 @@ export async function generateJWTVerifierInputs(
// Find the starting indices of the required substrings
const jwtTypStartIndex = header.indexOf('"typ":"JWT"');
const jwtAlgStartIndex = header.indexOf('"alg":"RS256"');
const commandStartIndex = payload.indexOf('"command":');
const azpKeyStartIndex = payload.indexOf('"azp":');

return {
message: Uint8ArrayToCharArray(messagePadded),
Expand All @@ -74,6 +76,7 @@ export async function generateJWTVerifierInputs(
periodIndex: periodIndex.toString(),
jwtTypStartIndex: jwtTypStartIndex.toString(),
jwtAlgStartIndex: jwtAlgStartIndex.toString(),
commandStartIndex: commandStartIndex.toString(),
azpKeyStartIndex: azpKeyStartIndex.toString(),
azp: Uint8ArrayToCharArray(Buffer.from(params.azp || "", "utf-8")),
};
}
3 changes: 2 additions & 1 deletion packages/helpers/tests/input-generators.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ describe("Generate JWT Verifier Inputs", () => {
expect(inputs.periodIndex).toBeDefined();
expect(inputs.jwtTypStartIndex).toBeDefined();
expect(inputs.jwtAlgStartIndex).toBeDefined();
expect(inputs.commandStartIndex).toBeDefined();
expect(inputs.azpKeyStartIndex).toBeDefined();
expect(inputs.azp).toBeInstanceOf(Array);
});

it("should throw an error for an invalid JWT", async () => {
Expand Down

0 comments on commit 9aee0fe

Please sign in to comment.