Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with the linearization of public signals #9

Merged
merged 5 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@solarity/zktype",
"version": "0.2.4",
"version": "0.2.5",
"description": "Unleash TypeScript bindings for Circom circuits",
"main": "dist/index.js",
"types": "dist/index.d.ts",
Expand Down
94 changes: 93 additions & 1 deletion src/core/CircuitArtifactGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ export default class CircuitArtifactGenerator {
};

const template = this._findTemplateForCircuit(ast.circomCompilerOutput, circuitArtifact.circuitName);
const templateArgs = this.getTemplateArgs(ast.circomCompilerOutput[0].main_component![1].Call.args, template.args);

for (const statement of template.body.Block.stmts) {
if (
Expand All @@ -105,12 +106,29 @@ export default class CircuitArtifactGenerator {
continue;
}

const dimensions = this.resolveDimension(statement.InitializationBlock.initializations[0].Declaration.dimensions);
const resolvedDimensions = dimensions.map((dimension: any) => {
if (typeof dimension === "string") {
const templateArg = templateArgs[dimension];

if (!templateArg) {
throw new Error(
`The template argument ${dimension} is missing in the circuit ${circuitArtifact.circuitName}`,
);
}

return Number(templateArg);
}

return Number(dimension);
});

const signal: Signal = {
type: statement.InitializationBlock.xtype.Signal[0] as SignalType,
internalType: this._getInternalType(statement.InitializationBlock.initializations[0].Declaration),
visibility: this._getSignalVisibility(ast.circomCompilerOutput[0], statement),
name: statement.InitializationBlock.initializations[0].Declaration.name,
dimensions: statement.InitializationBlock.initializations[0].Declaration.dimensions.length,
dimensions: resolvedDimensions,
};

circuitArtifact.signals.push(signal);
Expand All @@ -119,6 +137,80 @@ export default class CircuitArtifactGenerator {
return circuitArtifact;
}

private getTemplateArgs(args: string[], names: any[]): Record<string, bigint> {
if (args.length === 0) {
return {};
}

const result: Record<string, bigint> = {};

for (let i = 0; i < args.length; i++) {
const argObj = (args[i] as any)["Number"];

result[names[i]] = BigInt(this.resolveNumber(argObj));
}

return result;
}

private resolveVariable(variableObj: any) {
if (!variableObj || !variableObj.name) {
throw new Error(`The argument ${variableObj} is not a variable`);
}

return variableObj.name;
}

private resolveNumber(numberObj: any) {
if (!numberObj || !numberObj.length || numberObj.length < 2) {
throw new Error(`The argument ${numberObj} is not a number`);
}

if (!numberObj[1] || !numberObj[1].length || numberObj[1].length < 2) {
throw new Error(`The argument ${numberObj} is of unexpected format`);
}

const actualArg = numberObj[1][1];

if (!actualArg || !actualArg.length || numberObj[1].length < 1) {
throw new Error(`The argument ${numberObj} is of unexpected format`);
}

return actualArg[0];
}

private resolveDimension(dimensions: number[]): number[] {
const result: number[] = [];

for (const dimension of dimensions) {
if (dimension === 0) {
result.push(0);

continue;
}

const numberObj = (dimension as any)["Number"];
const variableObj = (dimension as any)["Variable"];

if (
(numberObj !== undefined && variableObj !== undefined) ||
(numberObj === undefined && variableObj === undefined)
) {
throw new Error(`The dimension ${dimension} is of unexpected format`);
}

if (numberObj) {
result.push(this.resolveNumber(numberObj));

continue;
}

result.push(this.resolveVariable(variableObj));
}

return result;
}

/**
* Cleans the artifacts directory by removing all files and subdirectories.
*/
Expand Down
16 changes: 13 additions & 3 deletions src/core/CircuitTypesGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
recursive: true,
});

const preparedNode = await this._returnTSDefinitionByArtifact(circuitArtifacts[i]);
const pathToGeneratedFile = path.join(this._projectRoot, this.getOutputTypesDir(), circuitTypePath);
const preparedNode = await this._returnTSDefinitionByArtifact(circuitArtifacts[i], pathToGeneratedFile);

this._saveFileContent(circuitTypePath, preparedNode);

Expand All @@ -100,6 +101,11 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
}

await this._resolveTypePaths(typePathsToResolve);

// copy utils to types output dir
const utilsDirPath = path.join(this._projectRoot, this.getOutputTypesDir());
fs.mkdirSync(utilsDirPath, { recursive: true });
fs.copyFileSync(path.join(__dirname, "templates", "utils.ts"), path.join(utilsDirPath, "utils.ts"));
}

/**
Expand Down Expand Up @@ -246,10 +252,14 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
* ```
*
* @param {CircuitArtifact} circuitArtifact - The circuit artifact for which the TypeScript bindings are generated.
* @param pathToGeneratedFile - The path to the generated file.
* @returns {string} The relative to the TYPES_DIR path to the generated file.
*/
private async _returnTSDefinitionByArtifact(circuitArtifact: CircuitArtifact): Promise<string> {
return await this._genCircuitWrapperClassContent(circuitArtifact);
private async _returnTSDefinitionByArtifact(
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
): Promise<string> {
return await this._genCircuitWrapperClassContent(circuitArtifact, pathToGeneratedFile);
}

/**
Expand Down
25 changes: 21 additions & 4 deletions src/core/ZkitTSGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
.join(".");
}

protected async _genCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise<string> {
protected async _genCircuitWrapperClassContent(
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
): Promise<string> {
const template = fs.readFileSync(path.join(__dirname, "templates", "circuit-wrapper.ts.ejs"), "utf8");

let outputCounter: number = 0;
Expand All @@ -73,7 +76,11 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
const privateInputs: Inputs[] = circuitArtifact.signals
.filter((signal) => signal.type != SignalTypeNames.Output)
.map((signal) => {
return { name: signal.name, dimensions: "[]".repeat(signal.dimensions) };
return {
name: signal.name,
dimensions: "[]".repeat(signal.dimensions.length),
dimensionsArray: new Array(signal.dimensions).join(", "),
};
});

for (const signal of circuitArtifact.signals) {
Expand All @@ -82,15 +89,24 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
}

if (signal.type === SignalTypeNames.Output) {
publicInputs.splice(outputCounter, 0, { name: signal.name, dimensions: "[]".repeat(signal.dimensions) });
publicInputs.splice(outputCounter, 0, {
name: signal.name,
dimensions: "[]".repeat(signal.dimensions.length),
dimensionsArray: new Array(signal.dimensions).join(", "),
});

outputCounter++;
continue;
}

publicInputs.push({ name: signal.name, dimensions: "[]".repeat(signal.dimensions) });
publicInputs.push({
name: signal.name,
dimensions: "[]".repeat(signal.dimensions.length),
dimensionsArray: new Array(signal.dimensions).join(", "),
});
}

const pathToUtils = path.join(this._projectRoot, this.getOutputTypesDir(), "utils");
const templateParams: WrapperTemplateParams = {
circuitClassName: this._getCircuitName(circuitArtifact),
publicInputsTypeName: this._getTypeName(circuitArtifact, "Public"),
Expand All @@ -99,6 +115,7 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
privateInputs,
proofTypeName: this._getTypeName(circuitArtifact, "Proof"),
privateInputsTypeName: this._getTypeName(circuitArtifact, "Private"),
pathToUtils: path.relative(path.dirname(pathToGeneratedFile), pathToUtils),
};

return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" });
Expand Down
24 changes: 14 additions & 10 deletions src/core/templates/circuit-wrapper.ts.ejs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import {
PublicSignals,
} from "@solarity/zkit";

import { normalizePublicSignals, denormalizePublicSignals } from "<%= pathToUtils %>";

export type <%= privateInputsTypeName %> = {
<% for (let i = 0; i < privateInputs.length; i++) { -%>
<%= privateInputs[i].name %>: NumberLike <%= privateInputs[i].dimensions %>;
Expand All @@ -15,7 +17,7 @@ export type <%= privateInputsTypeName %> = {

export type <%= publicInputsTypeName %> = {
<% for (let i = 0; i < publicInputs.length; i++) { -%>
<%= publicInputs[i].name %>: NumericString <%= publicInputs[i].dimensions %>;
<%= publicInputs[i].name %>: NumberLike <%= publicInputs[i].dimensions %>;
<% } -%>
}

Expand Down Expand Up @@ -71,19 +73,21 @@ export class <%= circuitClassName %> extends CircuitZKit {
];
}

private _normalizePublicSignals(publicSignals: PublicSignals): <%= publicInputsTypeName %> {
const signalNames = this.getSignalNames();
public getSignalDimensions(name: string): number[] {
switch (name) {
<% for (let i = 0; i < publicInputs.length; i++) { -%>
case "<%= publicInputs[i].name %>": return [<%= publicInputs[i].dimensionsArray %>];
<% } -%>
default: throw new Error(`Unknown signal name: ${name}`);
}
}

return signalNames.reduce((acc: any, signalName, index) => {
acc[signalName] = publicSignals[index];
return acc;
}, {});
private _normalizePublicSignals(publicSignals: PublicSignals): <%= publicInputsTypeName %> {
return normalizePublicSignals(publicSignals, this.getSignalNames(), this.getSignalDimensions);
}

private _denormalizePublicSignals(publicSignals: <%= publicInputsTypeName %>): PublicSignals {
const signalNames = this.getSignalNames();

return signalNames.map((signalName) => (publicSignals as any)[signalName]);
return denormalizePublicSignals(publicSignals, this.getSignalNames());
}
}

Expand Down
44 changes: 44 additions & 0 deletions src/core/templates/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { PublicSignals } from "@solarity/zkit";

export function normalizePublicSignals(
publicSignals: any[],
signalNames: string[],
getSignalDimensions: (name: string) => number[],
): any {
let index = 0;
return signalNames.reduce((acc: any, signalName) => {
const dimensions = getSignalDimensions(signalName);
const size = dimensions.reduce((a, b) => a * b, 1);

acc[signalName] = reshape(publicSignals.slice(index, index + size), dimensions);
index += size;

return acc;
}, {});
}

export function denormalizePublicSignals(publicSignals: any, signalNames: string[]): PublicSignals {
return signalNames.reduce((acc: any[], signalName) => {
return acc.concat(flatten(publicSignals[signalName]));
}, []);
}

function reshape(array: number[], dimensions: number[]): any {
if (dimensions.length === 0) {
return array[0];
}

const [first, ...rest] = dimensions;
const size = rest.reduce((a, b) => a * b, 1);

const result = [];
for (let i = 0; i < first; i++) {
result.push(reshape(array.slice(i * size, (i + 1) * size), rest));
}

return result;
}

function flatten(array: any): number[] {
return Array.isArray(array) ? array.flatMap((array) => flatten(array)) : array;
}
5 changes: 3 additions & 2 deletions src/types/circuitArtifact.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ export interface CircuitArtifact {
* @param {SignalType} type - The type of the signal (possible values: `Input`, `Output`).
* @param {string} visibility - The visibility of the signal (possible values: `public`, `private`).
* @param {string} internalType - The internal type of the signal (only possible value: `bigint`).
* @param {number} dimensions - The number of dimensions of the signal. If the signal is a scalar, the value is `0`.
* @param {number} dimensions - The array of dimensions of the signal. If the signal is a scalar, the value is `[]`.
* For example for a signal a[2][3], the value is `[2, 3]`.
*/
export interface Signal {
name: string;
type: SignalType;
visibility: SignalVisibility;
internalType: string;
dimensions: number;
dimensions: number[];
}
2 changes: 2 additions & 0 deletions src/types/typesGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export interface ArtifactWithPath {
export interface Inputs {
name: string;
dimensions: string;
dimensionsArray: string;
}

export interface WrapperTemplateParams {
Expand All @@ -18,6 +19,7 @@ export interface WrapperTemplateParams {
proofTypeName: string;
privateInputsTypeName: string;
circuitClassName: string;
pathToUtils: string;
}

export interface CircuitClass {
Expand Down
Loading
Loading