Skip to content

Commit

Permalink
Fix bug with the linearization of public signals (#9)
Browse files Browse the repository at this point in the history
* Fixed bug with the linearization of public signals

* Changed to Public signals to NumberLike

* Used flatMap

* Reduced code duplication

* Updated versions
  • Loading branch information
KyrylR authored Aug 5, 2024
1 parent 579cac6 commit 5b9b228
Show file tree
Hide file tree
Showing 12 changed files with 252 additions and 27 deletions.
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@solarity/zktype",
"version": "0.2.4",
"version": "0.2.5",
"description": "Unleash TypeScript bindings for Circom circuits",
"main": "dist/index.js",
"types": "dist/index.d.ts",
Expand Down
94 changes: 93 additions & 1 deletion src/core/CircuitArtifactGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ export default class CircuitArtifactGenerator {
};

const template = this._findTemplateForCircuit(ast.circomCompilerOutput, circuitArtifact.circuitName);
const templateArgs = this.getTemplateArgs(ast.circomCompilerOutput[0].main_component![1].Call.args, template.args);

for (const statement of template.body.Block.stmts) {
if (
Expand All @@ -105,12 +106,29 @@ export default class CircuitArtifactGenerator {
continue;
}

const dimensions = this.resolveDimension(statement.InitializationBlock.initializations[0].Declaration.dimensions);
const resolvedDimensions = dimensions.map((dimension: any) => {
if (typeof dimension === "string") {
const templateArg = templateArgs[dimension];

if (!templateArg) {
throw new Error(
`The template argument ${dimension} is missing in the circuit ${circuitArtifact.circuitName}`,
);
}

return Number(templateArg);
}

return Number(dimension);
});

const signal: Signal = {
type: statement.InitializationBlock.xtype.Signal[0] as SignalType,
internalType: this._getInternalType(statement.InitializationBlock.initializations[0].Declaration),
visibility: this._getSignalVisibility(ast.circomCompilerOutput[0], statement),
name: statement.InitializationBlock.initializations[0].Declaration.name,
dimensions: statement.InitializationBlock.initializations[0].Declaration.dimensions.length,
dimensions: resolvedDimensions,
};

circuitArtifact.signals.push(signal);
Expand All @@ -119,6 +137,80 @@ export default class CircuitArtifactGenerator {
return circuitArtifact;
}

private getTemplateArgs(args: string[], names: any[]): Record<string, bigint> {
if (args.length === 0) {
return {};
}

const result: Record<string, bigint> = {};

for (let i = 0; i < args.length; i++) {
const argObj = (args[i] as any)["Number"];

result[names[i]] = BigInt(this.resolveNumber(argObj));
}

return result;
}

private resolveVariable(variableObj: any) {
if (!variableObj || !variableObj.name) {
throw new Error(`The argument ${variableObj} is not a variable`);
}

return variableObj.name;
}

private resolveNumber(numberObj: any) {
if (!numberObj || !numberObj.length || numberObj.length < 2) {
throw new Error(`The argument ${numberObj} is not a number`);
}

if (!numberObj[1] || !numberObj[1].length || numberObj[1].length < 2) {
throw new Error(`The argument ${numberObj} is of unexpected format`);
}

const actualArg = numberObj[1][1];

if (!actualArg || !actualArg.length || numberObj[1].length < 1) {
throw new Error(`The argument ${numberObj} is of unexpected format`);
}

return actualArg[0];
}

private resolveDimension(dimensions: number[]): number[] {
const result: number[] = [];

for (const dimension of dimensions) {
if (dimension === 0) {
result.push(0);

continue;
}

const numberObj = (dimension as any)["Number"];
const variableObj = (dimension as any)["Variable"];

if (
(numberObj !== undefined && variableObj !== undefined) ||
(numberObj === undefined && variableObj === undefined)
) {
throw new Error(`The dimension ${dimension} is of unexpected format`);
}

if (numberObj) {
result.push(this.resolveNumber(numberObj));

continue;
}

result.push(this.resolveVariable(variableObj));
}

return result;
}

/**
* Cleans the artifacts directory by removing all files and subdirectories.
*/
Expand Down
16 changes: 13 additions & 3 deletions src/core/CircuitTypesGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
recursive: true,
});

const preparedNode = await this._returnTSDefinitionByArtifact(circuitArtifacts[i]);
const pathToGeneratedFile = path.join(this._projectRoot, this.getOutputTypesDir(), circuitTypePath);
const preparedNode = await this._returnTSDefinitionByArtifact(circuitArtifacts[i], pathToGeneratedFile);

this._saveFileContent(circuitTypePath, preparedNode);

Expand All @@ -100,6 +101,11 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
}

await this._resolveTypePaths(typePathsToResolve);

// copy utils to types output dir
const utilsDirPath = path.join(this._projectRoot, this.getOutputTypesDir());
fs.mkdirSync(utilsDirPath, { recursive: true });
fs.copyFileSync(path.join(__dirname, "templates", "utils.ts"), path.join(utilsDirPath, "utils.ts"));
}

/**
Expand Down Expand Up @@ -246,10 +252,14 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
* ```
*
* @param {CircuitArtifact} circuitArtifact - The circuit artifact for which the TypeScript bindings are generated.
* @param pathToGeneratedFile - The path to the generated file.
* @returns {string} The relative to the TYPES_DIR path to the generated file.
*/
private async _returnTSDefinitionByArtifact(circuitArtifact: CircuitArtifact): Promise<string> {
return await this._genCircuitWrapperClassContent(circuitArtifact);
private async _returnTSDefinitionByArtifact(
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
): Promise<string> {
return await this._genCircuitWrapperClassContent(circuitArtifact, pathToGeneratedFile);
}

/**
Expand Down
25 changes: 21 additions & 4 deletions src/core/ZkitTSGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
.join(".");
}

protected async _genCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise<string> {
protected async _genCircuitWrapperClassContent(
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
): Promise<string> {
const template = fs.readFileSync(path.join(__dirname, "templates", "circuit-wrapper.ts.ejs"), "utf8");

let outputCounter: number = 0;
Expand All @@ -73,7 +76,11 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
const privateInputs: Inputs[] = circuitArtifact.signals
.filter((signal) => signal.type != SignalTypeNames.Output)
.map((signal) => {
return { name: signal.name, dimensions: "[]".repeat(signal.dimensions) };
return {
name: signal.name,
dimensions: "[]".repeat(signal.dimensions.length),
dimensionsArray: new Array(signal.dimensions).join(", "),
};
});

for (const signal of circuitArtifact.signals) {
Expand All @@ -82,15 +89,24 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
}

if (signal.type === SignalTypeNames.Output) {
publicInputs.splice(outputCounter, 0, { name: signal.name, dimensions: "[]".repeat(signal.dimensions) });
publicInputs.splice(outputCounter, 0, {
name: signal.name,
dimensions: "[]".repeat(signal.dimensions.length),
dimensionsArray: new Array(signal.dimensions).join(", "),
});

outputCounter++;
continue;
}

publicInputs.push({ name: signal.name, dimensions: "[]".repeat(signal.dimensions) });
publicInputs.push({
name: signal.name,
dimensions: "[]".repeat(signal.dimensions.length),
dimensionsArray: new Array(signal.dimensions).join(", "),
});
}

const pathToUtils = path.join(this._projectRoot, this.getOutputTypesDir(), "utils");
const templateParams: WrapperTemplateParams = {
circuitClassName: this._getCircuitName(circuitArtifact),
publicInputsTypeName: this._getTypeName(circuitArtifact, "Public"),
Expand All @@ -99,6 +115,7 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
privateInputs,
proofTypeName: this._getTypeName(circuitArtifact, "Proof"),
privateInputsTypeName: this._getTypeName(circuitArtifact, "Private"),
pathToUtils: path.relative(path.dirname(pathToGeneratedFile), pathToUtils),
};

return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" });
Expand Down
24 changes: 14 additions & 10 deletions src/core/templates/circuit-wrapper.ts.ejs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import {
PublicSignals,
} from "@solarity/zkit";

import { normalizePublicSignals, denormalizePublicSignals } from "<%= pathToUtils %>";

export type <%= privateInputsTypeName %> = {
<% for (let i = 0; i < privateInputs.length; i++) { -%>
<%= privateInputs[i].name %>: NumberLike <%= privateInputs[i].dimensions %>;
Expand All @@ -15,7 +17,7 @@ export type <%= privateInputsTypeName %> = {

export type <%= publicInputsTypeName %> = {
<% for (let i = 0; i < publicInputs.length; i++) { -%>
<%= publicInputs[i].name %>: NumericString <%= publicInputs[i].dimensions %>;
<%= publicInputs[i].name %>: NumberLike <%= publicInputs[i].dimensions %>;
<% } -%>
}

Expand Down Expand Up @@ -71,19 +73,21 @@ export class <%= circuitClassName %> extends CircuitZKit {
];
}

private _normalizePublicSignals(publicSignals: PublicSignals): <%= publicInputsTypeName %> {
const signalNames = this.getSignalNames();
public getSignalDimensions(name: string): number[] {
switch (name) {
<% for (let i = 0; i < publicInputs.length; i++) { -%>
case "<%= publicInputs[i].name %>": return [<%= publicInputs[i].dimensionsArray %>];
<% } -%>
default: throw new Error(`Unknown signal name: ${name}`);
}
}

return signalNames.reduce((acc: any, signalName, index) => {
acc[signalName] = publicSignals[index];
return acc;
}, {});
private _normalizePublicSignals(publicSignals: PublicSignals): <%= publicInputsTypeName %> {
return normalizePublicSignals(publicSignals, this.getSignalNames(), this.getSignalDimensions);
}

private _denormalizePublicSignals(publicSignals: <%= publicInputsTypeName %>): PublicSignals {
const signalNames = this.getSignalNames();

return signalNames.map((signalName) => (publicSignals as any)[signalName]);
return denormalizePublicSignals(publicSignals, this.getSignalNames());
}
}

Expand Down
44 changes: 44 additions & 0 deletions src/core/templates/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { PublicSignals } from "@solarity/zkit";

export function normalizePublicSignals(
publicSignals: any[],
signalNames: string[],
getSignalDimensions: (name: string) => number[],
): any {
let index = 0;
return signalNames.reduce((acc: any, signalName) => {
const dimensions = getSignalDimensions(signalName);
const size = dimensions.reduce((a, b) => a * b, 1);

acc[signalName] = reshape(publicSignals.slice(index, index + size), dimensions);
index += size;

return acc;
}, {});
}

export function denormalizePublicSignals(publicSignals: any, signalNames: string[]): PublicSignals {
return signalNames.reduce((acc: any[], signalName) => {
return acc.concat(flatten(publicSignals[signalName]));
}, []);
}

function reshape(array: number[], dimensions: number[]): any {
if (dimensions.length === 0) {
return array[0];
}

const [first, ...rest] = dimensions;
const size = rest.reduce((a, b) => a * b, 1);

const result = [];
for (let i = 0; i < first; i++) {
result.push(reshape(array.slice(i * size, (i + 1) * size), rest));
}

return result;
}

function flatten(array: any): number[] {
return Array.isArray(array) ? array.flatMap((array) => flatten(array)) : array;
}
5 changes: 3 additions & 2 deletions src/types/circuitArtifact.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ export interface CircuitArtifact {
* @param {SignalType} type - The type of the signal (possible values: `Input`, `Output`).
* @param {string} visibility - The visibility of the signal (possible values: `public`, `private`).
* @param {string} internalType - The internal type of the signal (only possible value: `bigint`).
* @param {number} dimensions - The number of dimensions of the signal. If the signal is a scalar, the value is `0`.
* @param {number} dimensions - The array of dimensions of the signal. If the signal is a scalar, the value is `[]`.
* For example for a signal a[2][3], the value is `[2, 3]`.
*/
export interface Signal {
name: string;
type: SignalType;
visibility: SignalVisibility;
internalType: string;
dimensions: number;
dimensions: number[];
}
2 changes: 2 additions & 0 deletions src/types/typesGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export interface ArtifactWithPath {
export interface Inputs {
name: string;
dimensions: string;
dimensionsArray: string;
}

export interface WrapperTemplateParams {
Expand All @@ -18,6 +19,7 @@ export interface WrapperTemplateParams {
proofTypeName: string;
privateInputsTypeName: string;
circuitClassName: string;
pathToUtils: string;
}

export interface CircuitClass {
Expand Down
Loading

0 comments on commit 5b9b228

Please sign in to comment.