From ae2de8fc7a5134e8539c8c7d8cd05273670c23cd Mon Sep 17 00:00:00 2001 From: Kyryl Riabov Date: Fri, 2 Aug 2024 19:42:42 +0300 Subject: [PATCH 1/5] Fixed bug with the linearization of public signals --- src/core/CircuitArtifactGenerator.ts | 94 ++++++++++++++++++++++- src/core/CircuitTypesGenerator.ts | 16 +++- src/core/ZkitTSGenerator.ts | 25 +++++- src/core/templates/circuit-wrapper.ts.ejs | 23 +++++- src/core/templates/utils.ts | 25 ++++++ src/types/circuitArtifact.ts | 5 +- src/types/typesGenerator.ts | 2 + test/CircuitProofGeneration.test.ts | 40 +++++++++- test/fixture/auth/Matrix.circom | 22 ++++++ test/helpers/generateTypes.ts | 1 + 10 files changed, 236 insertions(+), 17 deletions(-) create mode 100644 src/core/templates/utils.ts create mode 100644 test/fixture/auth/Matrix.circom diff --git a/src/core/CircuitArtifactGenerator.ts b/src/core/CircuitArtifactGenerator.ts index abe8afe..6f80f34 100644 --- a/src/core/CircuitArtifactGenerator.ts +++ b/src/core/CircuitArtifactGenerator.ts @@ -95,6 +95,7 @@ export default class CircuitArtifactGenerator { }; const template = this._findTemplateForCircuit(ast.circomCompilerOutput, circuitArtifact.circuitName); + const templateArgs = this.getTemplateArgs(ast.circomCompilerOutput[0].main_component![1].Call.args, template.args); for (const statement of template.body.Block.stmts) { if ( @@ -105,12 +106,29 @@ export default class CircuitArtifactGenerator { continue; } + const dimensions = this.resolveDimension(statement.InitializationBlock.initializations[0].Declaration.dimensions); + const resolvedDimensions = dimensions.map((dimension: any) => { + if (typeof dimension === "string") { + const templateArg = templateArgs[dimension]; + + if (!templateArg) { + throw new Error( + `The template argument ${dimension} is missing in the circuit ${circuitArtifact.circuitName}`, + ); + } + + return Number(templateArg); + } + + return Number(dimension); + }); + const signal: Signal = { type: statement.InitializationBlock.xtype.Signal[0] as SignalType, internalType: this._getInternalType(statement.InitializationBlock.initializations[0].Declaration), visibility: this._getSignalVisibility(ast.circomCompilerOutput[0], statement), name: statement.InitializationBlock.initializations[0].Declaration.name, - dimensions: statement.InitializationBlock.initializations[0].Declaration.dimensions.length, + dimensions: resolvedDimensions, }; circuitArtifact.signals.push(signal); @@ -119,6 +137,80 @@ export default class CircuitArtifactGenerator { return circuitArtifact; } + private getTemplateArgs(args: string[], names: any[]): Record { + if (args.length === 0) { + return {}; + } + + const result: Record = {}; + + for (let i = 0; i < args.length; i++) { + const argObj = (args[i] as any)["Number"]; + + result[names[i]] = BigInt(this.resolveNumber(argObj)); + } + + return result; + } + + private resolveVariable(variableObj: any) { + if (!variableObj || !variableObj.name) { + throw new Error(`The argument ${variableObj} is not a variable`); + } + + return variableObj.name; + } + + private resolveNumber(numberObj: any) { + if (!numberObj || !numberObj.length || numberObj.length < 2) { + throw new Error(`The argument ${numberObj} is not a number`); + } + + if (!numberObj[1] || !numberObj[1].length || numberObj[1].length < 2) { + throw new Error(`The argument ${numberObj} is of unexpected format`); + } + + const actualArg = numberObj[1][1]; + + if (!actualArg || !actualArg.length || numberObj[1].length < 1) { + throw new Error(`The argument ${numberObj} is of unexpected format`); + } + + return actualArg[0]; + } + + private resolveDimension(dimensions: number[]): number[] { + const result: number[] = []; + + for (const dimension of dimensions) { + if (dimension === 0) { + result.push(0); + + continue; + } + + const numberObj = (dimension as any)["Number"]; + const variableObj = (dimension as any)["Variable"]; + + if ( + (numberObj !== undefined && variableObj !== undefined) || + (numberObj === undefined && variableObj === undefined) + ) { + throw new Error(`The dimension ${dimension} is of unexpected format`); + } + + if (numberObj) { + result.push(this.resolveNumber(numberObj)); + + continue; + } + + result.push(this.resolveVariable(variableObj)); + } + + return result; + } + /** * Cleans the artifacts directory by removing all files and subdirectories. */ diff --git a/src/core/CircuitTypesGenerator.ts b/src/core/CircuitTypesGenerator.ts index 952350d..6be65fa 100644 --- a/src/core/CircuitTypesGenerator.ts +++ b/src/core/CircuitTypesGenerator.ts @@ -89,7 +89,8 @@ export class CircuitTypesGenerator extends ZkitTSGenerator { recursive: true, }); - const preparedNode = await this._returnTSDefinitionByArtifact(circuitArtifacts[i]); + const pathToGeneratedFile = path.join(this._projectRoot, this.getOutputTypesDir(), circuitTypePath); + const preparedNode = await this._returnTSDefinitionByArtifact(circuitArtifacts[i], pathToGeneratedFile); this._saveFileContent(circuitTypePath, preparedNode); @@ -100,6 +101,11 @@ export class CircuitTypesGenerator extends ZkitTSGenerator { } await this._resolveTypePaths(typePathsToResolve); + + // copy utils to types output dir + const utilsDirPath = path.join(this._projectRoot, this.getOutputTypesDir()); + fs.mkdirSync(utilsDirPath, { recursive: true }); + fs.copyFileSync(path.join(__dirname, "templates", "utils.ts"), path.join(utilsDirPath, "utils.ts")); } /** @@ -246,10 +252,14 @@ export class CircuitTypesGenerator extends ZkitTSGenerator { * ``` * * @param {CircuitArtifact} circuitArtifact - The circuit artifact for which the TypeScript bindings are generated. + * @param pathToGeneratedFile - The path to the generated file. * @returns {string} The relative to the TYPES_DIR path to the generated file. */ - private async _returnTSDefinitionByArtifact(circuitArtifact: CircuitArtifact): Promise { - return await this._genCircuitWrapperClassContent(circuitArtifact); + private async _returnTSDefinitionByArtifact( + circuitArtifact: CircuitArtifact, + pathToGeneratedFile: string, + ): Promise { + return await this._genCircuitWrapperClassContent(circuitArtifact, pathToGeneratedFile); } /** diff --git a/src/core/ZkitTSGenerator.ts b/src/core/ZkitTSGenerator.ts index 27afb0e..1367f00 100644 --- a/src/core/ZkitTSGenerator.ts +++ b/src/core/ZkitTSGenerator.ts @@ -64,7 +64,10 @@ export default class ZkitTSGenerator extends BaseTSGenerator { .join("."); } - protected async _genCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise { + protected async _genCircuitWrapperClassContent( + circuitArtifact: CircuitArtifact, + pathToGeneratedFile: string, + ): Promise { const template = fs.readFileSync(path.join(__dirname, "templates", "circuit-wrapper.ts.ejs"), "utf8"); let outputCounter: number = 0; @@ -73,7 +76,11 @@ export default class ZkitTSGenerator extends BaseTSGenerator { const privateInputs: Inputs[] = circuitArtifact.signals .filter((signal) => signal.type != SignalTypeNames.Output) .map((signal) => { - return { name: signal.name, dimensions: "[]".repeat(signal.dimensions) }; + return { + name: signal.name, + dimensions: "[]".repeat(signal.dimensions.length), + dimensionsArray: new Array(signal.dimensions).join(", "), + }; }); for (const signal of circuitArtifact.signals) { @@ -82,15 +89,24 @@ export default class ZkitTSGenerator extends BaseTSGenerator { } if (signal.type === SignalTypeNames.Output) { - publicInputs.splice(outputCounter, 0, { name: signal.name, dimensions: "[]".repeat(signal.dimensions) }); + publicInputs.splice(outputCounter, 0, { + name: signal.name, + dimensions: "[]".repeat(signal.dimensions.length), + dimensionsArray: new Array(signal.dimensions).join(", "), + }); outputCounter++; continue; } - publicInputs.push({ name: signal.name, dimensions: "[]".repeat(signal.dimensions) }); + publicInputs.push({ + name: signal.name, + dimensions: "[]".repeat(signal.dimensions.length), + dimensionsArray: new Array(signal.dimensions).join(", "), + }); } + const pathToUtils = path.join(this._projectRoot, this.getOutputTypesDir(), "utils"); const templateParams: WrapperTemplateParams = { circuitClassName: this._getCircuitName(circuitArtifact), publicInputsTypeName: this._getTypeName(circuitArtifact, "Public"), @@ -99,6 +115,7 @@ export default class ZkitTSGenerator extends BaseTSGenerator { privateInputs, proofTypeName: this._getTypeName(circuitArtifact, "Proof"), privateInputsTypeName: this._getTypeName(circuitArtifact, "Private"), + pathToUtils: path.relative(path.dirname(pathToGeneratedFile), pathToUtils), }; return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" }); diff --git a/src/core/templates/circuit-wrapper.ts.ejs b/src/core/templates/circuit-wrapper.ts.ejs index c5a30b9..0f26a1c 100644 --- a/src/core/templates/circuit-wrapper.ts.ejs +++ b/src/core/templates/circuit-wrapper.ts.ejs @@ -7,6 +7,8 @@ import { PublicSignals, } from "@solarity/zkit"; +import { flatten, reshape } from "<%= pathToUtils %>"; + export type <%= privateInputsTypeName %> = { <% for (let i = 0; i < privateInputs.length; i++) { -%> <%= privateInputs[i].name %>: NumberLike <%= privateInputs[i].dimensions %>; @@ -71,11 +73,24 @@ export class <%= circuitClassName %> extends CircuitZKit { ]; } + public getSignalDimensions(name: string): number[] { + switch (name) { + <% for (let i = 0; i < publicInputs.length; i++) { -%> + case "<%= publicInputs[i].name %>": return [<%= publicInputs[i].dimensionsArray %>]; + <% } -%> + default: throw new Error(`Unknown signal name: ${name}`); + } + } + private _normalizePublicSignals(publicSignals: PublicSignals): <%= publicInputsTypeName %> { const signalNames = this.getSignalNames(); - return signalNames.reduce((acc: any, signalName, index) => { - acc[signalName] = publicSignals[index]; + let index = 0; + return signalNames.reduce((acc: any, signalName) => { + const dimensions = this.getSignalDimensions(signalName); + const size = dimensions.reduce((a, b) => a * b, 1); + acc[signalName] = reshape(publicSignals.slice(index, index + size), dimensions); + index += size; return acc; }, {}); } @@ -83,7 +98,9 @@ export class <%= circuitClassName %> extends CircuitZKit { private _denormalizePublicSignals(publicSignals: <%= publicInputsTypeName %>): PublicSignals { const signalNames = this.getSignalNames(); - return signalNames.map((signalName) => (publicSignals as any)[signalName]); + return signalNames.reduce((acc: any[], signalName) => { + return acc.concat(flatten(publicSignals[signalName])); + }, []); } } diff --git a/src/core/templates/utils.ts b/src/core/templates/utils.ts new file mode 100644 index 0000000..028c8f4 --- /dev/null +++ b/src/core/templates/utils.ts @@ -0,0 +1,25 @@ +export function reshape(array: number[], dimensions: number[]): any { + if (dimensions.length === 0) { + return array[0]; + } + + const [first, ...rest] = dimensions; + const size = rest.reduce((a, b) => a * b, 1); + + const result = []; + for (let i = 0; i < first; i++) { + result.push(reshape(array.slice(i * size, (i + 1) * size), rest)); + } + + return result; +} + +export function flatten(array: any): number[] { + if (!Array.isArray(array)) { + return [array]; + } + + return array.reduce((acc, value) => { + return acc.concat(flatten(value)); + }, []); +} diff --git a/src/types/circuitArtifact.ts b/src/types/circuitArtifact.ts index 120d559..d9ef10f 100644 --- a/src/types/circuitArtifact.ts +++ b/src/types/circuitArtifact.ts @@ -26,12 +26,13 @@ export interface CircuitArtifact { * @param {SignalType} type - The type of the signal (possible values: `Input`, `Output`). * @param {string} visibility - The visibility of the signal (possible values: `public`, `private`). * @param {string} internalType - The internal type of the signal (only possible value: `bigint`). - * @param {number} dimensions - The number of dimensions of the signal. If the signal is a scalar, the value is `0`. + * @param {number} dimensions - The array of dimensions of the signal. If the signal is a scalar, the value is `[]`. + * For example for a signal a[2][3], the value is `[2, 3]`. */ export interface Signal { name: string; type: SignalType; visibility: SignalVisibility; internalType: string; - dimensions: number; + dimensions: number[]; } diff --git a/src/types/typesGenerator.ts b/src/types/typesGenerator.ts index dc9dec1..133f452 100644 --- a/src/types/typesGenerator.ts +++ b/src/types/typesGenerator.ts @@ -8,6 +8,7 @@ export interface ArtifactWithPath { export interface Inputs { name: string; dimensions: string; + dimensionsArray: string; } export interface WrapperTemplateParams { @@ -18,6 +19,7 @@ export interface WrapperTemplateParams { proofTypeName: string; privateInputsTypeName: string; circuitClassName: string; + pathToUtils: string; } export interface CircuitClass { diff --git a/test/CircuitProofGeneration.test.ts b/test/CircuitProofGeneration.test.ts index c646884..ee29cb5 100644 --- a/test/CircuitProofGeneration.test.ts +++ b/test/CircuitProofGeneration.test.ts @@ -23,27 +23,59 @@ describe("Circuit Proof Generation", function () { ], }); - const config: CircuitZKitConfig = { + const basicConfig: CircuitZKitConfig = { circuitName: "Basic", circuitArtifactsPath: "test/cache/Basic", verifierDirPath: "test/cache", }; + const matrixConfig: CircuitZKitConfig = { + circuitName: "Matrix", + circuitArtifactsPath: "test/cache/Matrix", + verifierDirPath: "test/cache", + }; + beforeEach(async () => { const preprocessor = await generateAST("test/fixture", astDir, true, [], []); await circuitTypesGenerator.generateTypes(); - await preprocessor.circuitCompiler.compileCircuit("test/fixture/Basic.circom", config.circuitArtifactsPath); + await preprocessor.circuitCompiler.compileCircuit("test/fixture/Basic.circom", basicConfig.circuitArtifactsPath); + await preprocessor.circuitCompiler.compileCircuit( + "test/fixture/auth/Matrix.circom", + matrixConfig.circuitArtifactsPath, + ); }); - it("should generate and verify proof", async () => { + it("should generate and verify proof for Basic.circom", async () => { const object = await circuitTypesGenerator.getCircuitObject("test/fixture/Basic.circom:Multiplier2"); - const circuit = new object(config); + const circuit = new object(basicConfig); const proof = await circuit.generateProof({ in1: 2, in2: 3 }); expect(await circuit.verifyProof(proof)).to.be.true; }); + it("should generate and verify proof for Matrix.circom", async () => { + const object = await circuitTypesGenerator.getCircuitObject("test/fixture/auth/Matrix.circom:Matrix"); + + const circuit = new object(matrixConfig); + + const proof = await circuit.generateProof({ + a: [ + [1n, 2n, 3n], + [1n, 2n, 3n], + [1n, 2n, 3n], + ], + b: [ + [1n, 2n, 3n], + [1n, 2n, 3n], + [1n, 2n, 3n], + ], + c: 9n, + }); + + expect(await circuit.verifyProof(proof)).to.be.true; + }); + it("should correctly import all of the zktype objects", async () => { new (await circuitTypesGenerator.getCircuitObject("test/fixture/Basic.circom:Multiplier2"))(); new (await circuitTypesGenerator.getCircuitObject("test/fixture/auth/BasicInAuth.circom:Multiplier2"))(); diff --git a/test/fixture/auth/Matrix.circom b/test/fixture/auth/Matrix.circom new file mode 100644 index 0000000..37de474 --- /dev/null +++ b/test/fixture/auth/Matrix.circom @@ -0,0 +1,22 @@ +pragma circom 2.1.8; + +template Matrix (n) { + signal input a[3][3]; + signal input b[3][3]; + + signal input c; + + signal unused <-- a[0][0] * b[0][0]; + + signal output d[3][3]; + signal output e[3][3]; + + for (var i = 0; i < 2; i++) { + for (var j = 0; j < 2; j++) { + d[i][j] <== a[i][j] * b[i][j] + c; + e[i][j] <== a[i][j] * b[i][j]; + } + } +} + +component main {public [a, b]} = Matrix(12); diff --git a/test/helpers/generateTypes.ts b/test/helpers/generateTypes.ts index 736f3aa..356931a 100644 --- a/test/helpers/generateTypes.ts +++ b/test/helpers/generateTypes.ts @@ -15,6 +15,7 @@ const circuitTypesGenerator = new CircuitTypesGenerator({ "test/cache/circuits-ast/lib/BasicInLib.json", "test/cache/circuits-ast/auth/EMultiplier.json", "test/cache/circuits-ast/auth/BasicInAuth.json", + "test/cache/circuits-ast/auth/Matrix.json", ], }); From fcc1c143ea60be3c2a95079077d48542fe2483ea Mon Sep 17 00:00:00 2001 From: Kyryl Riabov Date: Fri, 2 Aug 2024 22:26:40 +0300 Subject: [PATCH 2/5] Changed to Public signals to NumberLike --- src/core/templates/circuit-wrapper.ts.ejs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/templates/circuit-wrapper.ts.ejs b/src/core/templates/circuit-wrapper.ts.ejs index 0f26a1c..5f0b55e 100644 --- a/src/core/templates/circuit-wrapper.ts.ejs +++ b/src/core/templates/circuit-wrapper.ts.ejs @@ -17,7 +17,7 @@ export type <%= privateInputsTypeName %> = { export type <%= publicInputsTypeName %> = { <% for (let i = 0; i < publicInputs.length; i++) { -%> - <%= publicInputs[i].name %>: NumericString <%= publicInputs[i].dimensions %>; + <%= publicInputs[i].name %>: NumberLike <%= publicInputs[i].dimensions %>; <% } -%> } From ca97f4902d8bb0dd4385db901942a1661b715bb0 Mon Sep 17 00:00:00 2001 From: Kyryl Riabov Date: Sat, 3 Aug 2024 17:30:11 +0300 Subject: [PATCH 3/5] Used flatMap --- src/core/templates/utils.ts | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/core/templates/utils.ts b/src/core/templates/utils.ts index 028c8f4..d3d6807 100644 --- a/src/core/templates/utils.ts +++ b/src/core/templates/utils.ts @@ -15,11 +15,5 @@ export function reshape(array: number[], dimensions: number[]): any { } export function flatten(array: any): number[] { - if (!Array.isArray(array)) { - return [array]; - } - - return array.reduce((acc, value) => { - return acc.concat(flatten(value)); - }, []); + return Array.isArray(array) ? array.flatMap((array) => flatten(array)) : array; } From 54cd525bd59b1947af4a8644ae6cfba5a6e3f619 Mon Sep 17 00:00:00 2001 From: Kyryl Riabov Date: Sat, 3 Aug 2024 17:41:02 +0300 Subject: [PATCH 4/5] Reduced code duplication --- src/core/templates/circuit-wrapper.ts.ejs | 19 +++------------ src/core/templates/utils.ts | 29 +++++++++++++++++++++-- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/src/core/templates/circuit-wrapper.ts.ejs b/src/core/templates/circuit-wrapper.ts.ejs index 5f0b55e..f2a372f 100644 --- a/src/core/templates/circuit-wrapper.ts.ejs +++ b/src/core/templates/circuit-wrapper.ts.ejs @@ -7,7 +7,7 @@ import { PublicSignals, } from "@solarity/zkit"; -import { flatten, reshape } from "<%= pathToUtils %>"; +import { normalizePublicSignals, denormalizePublicSignals } from "<%= pathToUtils %>"; export type <%= privateInputsTypeName %> = { <% for (let i = 0; i < privateInputs.length; i++) { -%> @@ -83,24 +83,11 @@ export class <%= circuitClassName %> extends CircuitZKit { } private _normalizePublicSignals(publicSignals: PublicSignals): <%= publicInputsTypeName %> { - const signalNames = this.getSignalNames(); - - let index = 0; - return signalNames.reduce((acc: any, signalName) => { - const dimensions = this.getSignalDimensions(signalName); - const size = dimensions.reduce((a, b) => a * b, 1); - acc[signalName] = reshape(publicSignals.slice(index, index + size), dimensions); - index += size; - return acc; - }, {}); + return normalizePublicSignals(publicSignals, this.getSignalNames(), this.getSignalDimensions); } private _denormalizePublicSignals(publicSignals: <%= publicInputsTypeName %>): PublicSignals { - const signalNames = this.getSignalNames(); - - return signalNames.reduce((acc: any[], signalName) => { - return acc.concat(flatten(publicSignals[signalName])); - }, []); + return denormalizePublicSignals(publicSignals, this.getSignalNames()); } } diff --git a/src/core/templates/utils.ts b/src/core/templates/utils.ts index d3d6807..edd1fdb 100644 --- a/src/core/templates/utils.ts +++ b/src/core/templates/utils.ts @@ -1,4 +1,29 @@ -export function reshape(array: number[], dimensions: number[]): any { +import { PublicSignals } from "@solarity/zkit"; + +export function normalizePublicSignals( + publicSignals: any[], + signalNames: string[], + getSignalDimensions: (name: string) => number[], +): any { + let index = 0; + return signalNames.reduce((acc: any, signalName) => { + const dimensions = getSignalDimensions(signalName); + const size = dimensions.reduce((a, b) => a * b, 1); + + acc[signalName] = reshape(publicSignals.slice(index, index + size), dimensions); + index += size; + + return acc; + }, {}); +} + +export function denormalizePublicSignals(publicSignals: any, signalNames: string[]): PublicSignals { + return signalNames.reduce((acc: any[], signalName) => { + return acc.concat(flatten(publicSignals[signalName])); + }, []); +} + +function reshape(array: number[], dimensions: number[]): any { if (dimensions.length === 0) { return array[0]; } @@ -14,6 +39,6 @@ export function reshape(array: number[], dimensions: number[]): any { return result; } -export function flatten(array: any): number[] { +function flatten(array: any): number[] { return Array.isArray(array) ? array.flatMap((array) => flatten(array)) : array; } From d698ddaf3636e1b60b1a092d791ea298a5ebcb9a Mon Sep 17 00:00:00 2001 From: Kyryl Riabov Date: Mon, 5 Aug 2024 13:52:01 +0300 Subject: [PATCH 5/5] Updated versions --- package-lock.json | 4 ++-- package.json | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/package-lock.json b/package-lock.json index 0c5da4e..3f809e0 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@solarity/zktype", - "version": "0.2.4", + "version": "0.2.5", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@solarity/zktype", - "version": "0.2.4", + "version": "0.2.5", "license": "MIT", "dependencies": { "ejs": "3.1.10", diff --git a/package.json b/package.json index fe251ae..95ea875 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@solarity/zktype", - "version": "0.2.4", + "version": "0.2.5", "description": "Unleash TypeScript bindings for Circom circuits", "main": "dist/index.js", "types": "dist/index.d.ts",