diff --git a/build.gradle b/build.gradle index 09dd4c0f35..8a70608865 100644 --- a/build.gradle +++ b/build.gradle @@ -215,6 +215,18 @@ allprojects { useJUnitPlatform() } + task unitTests(type: Test) { + useJUnitPlatform { + excludeTags("CorsetTest") + } + } + + task corsetTests(type: Test) { + useJUnitPlatform { + includeTags("CorsetTest") + } + } + javadoc { options.addStringOption('Xdoclint:all', '-quiet') options.addStringOption('Xwerror', '-html5') diff --git a/src/main/java/net/consensys/linea/zktracer/OpCode.java b/src/main/java/net/consensys/linea/zktracer/OpCode.java index 0625ac49f6..9b8f195b5e 100644 --- a/src/main/java/net/consensys/linea/zktracer/OpCode.java +++ b/src/main/java/net/consensys/linea/zktracer/OpCode.java @@ -37,6 +37,9 @@ public enum OpCode { SGT(0x13), EQ(0x14), ISZERO(0x15), + // mul + MUL(0x02), + EXP(0x0a), // shf SHL(0x1b), SHR(0x1c), diff --git a/src/main/java/net/consensys/linea/zktracer/ZkTracer.java b/src/main/java/net/consensys/linea/zktracer/ZkTracer.java index 6aa8e7e6dd..4f5f1df711 100644 --- a/src/main/java/net/consensys/linea/zktracer/ZkTracer.java +++ b/src/main/java/net/consensys/linea/zktracer/ZkTracer.java @@ -24,6 +24,7 @@ import net.consensys.linea.zktracer.module.ModuleTracer; import net.consensys.linea.zktracer.module.alu.add.AddTracer; import net.consensys.linea.zktracer.module.alu.mod.ModTracer; +import net.consensys.linea.zktracer.module.alu.mul.MulTracer; import net.consensys.linea.zktracer.module.shf.ShfTracer; import net.consensys.linea.zktracer.module.wcp.WcpTracer; @@ -42,7 +43,8 @@ public ZkTracer(final ZkTraceBuilder zkTraceBuilder, final List tr public ZkTracer(final ZkTraceBuilder zkTraceBuilder) { this( zkTraceBuilder, - List.of(new ShfTracer(), new WcpTracer(), new AddTracer(), new ModTracer())); + List.of( + new MulTracer(), new ShfTracer(), new WcpTracer(), new AddTracer(), new ModTracer())); } @Override diff --git a/src/main/java/net/consensys/linea/zktracer/bytes/Bytes16.java b/src/main/java/net/consensys/linea/zktracer/bytes/Bytes16.java index 3a4f7cfcad..9ff1ffc193 100644 --- a/src/main/java/net/consensys/linea/zktracer/bytes/Bytes16.java +++ b/src/main/java/net/consensys/linea/zktracer/bytes/Bytes16.java @@ -132,7 +132,7 @@ static Bytes16 leftPad(Bytes value) { * Right pad a {@link Bytes} value with zero bytes to create a {@link Bytes16}. * * @param value The bytes value pad. - * @return A {@link Bytes16} that exposes the rightw-padded bytes of {@code value}. + * @return A {@link Bytes16} that exposes the right-padded bytes of {@code value}. * @throws IllegalArgumentException if {@code value.size() > 16}. */ static Bytes16 rightPad(Bytes value) { diff --git a/src/main/java/net/consensys/linea/zktracer/module/Util.java b/src/main/java/net/consensys/linea/zktracer/module/Util.java index 8322cfc746..f97153ff01 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/Util.java +++ b/src/main/java/net/consensys/linea/zktracer/module/Util.java @@ -124,4 +124,16 @@ public static UInt256 multiplyRange(Bytes[] range1, Bytes[] range2) { } return sum; } + /** + * Converts a boolean value to a byte (1 for true and 0 for false). + * + * @param b The boolean value to be converted. + * @return A byte representing the input boolean value. + */ + public static byte boolToByte(boolean b) { + if (b) { + return 1; + } + return 0; + } } diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java new file mode 100644 index 0000000000..d512c0264a --- /dev/null +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -0,0 +1,406 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package net.consensys.linea.zktracer.module.alu.mul; + +import static net.consensys.linea.zktracer.module.Util.boolToByte; +import static net.consensys.linea.zktracer.module.Util.byteBits; +import static net.consensys.linea.zktracer.module.Util.getBit; +import static net.consensys.linea.zktracer.module.Util.getOverflow; + +import java.lang.reflect.Array; +import java.math.BigInteger; + +import net.consensys.linea.zktracer.OpCode; +import net.consensys.linea.zktracer.bytes.Bytes16; +import net.consensys.linea.zktracer.bytes.UnsignedByte; +import net.consensys.linea.zktracer.bytestheta.BaseBytes; +import net.consensys.linea.zktracer.bytestheta.BaseTheta; +import org.apache.tuweni.bytes.Bytes; +import org.apache.tuweni.bytes.Bytes32; +import org.apache.tuweni.units.bigints.UInt256; + +public class MulData { + private static final int MMEDIUM = 8; + final OpCode opCode; + final Bytes32 arg1; + final Bytes32 arg2; + + final Bytes16 arg1Hi; + final Bytes16 arg1Lo; + final Bytes16 arg2Hi; + final Bytes16 arg2Lo; + + final boolean tinyBase; + final boolean tinyExponent; + + UInt256 resAcc = + UInt256.ZERO; // accumulator which converges in a series of "square and multiply"'s + UInt256 expAcc = + UInt256.ZERO; // accumulator for doubles and adds of the exponent, resets at some point + + final BaseTheta aBytes; + final BaseTheta bBytes; + BaseTheta cBytes = BaseTheta.fromBytes32(Bytes32.ZERO); + BaseTheta hBytes = BaseTheta.fromBytes32(Bytes32.ZERO); + boolean snm = false; + int index; + Boolean[] bits = new Boolean[8]; + String exponentBits = "0"; + + BaseBytes res; + + public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { + + this.opCode = opCode; + this.arg1 = arg1; + this.arg2 = arg2; + this.aBytes = BaseTheta.fromBytes32(arg1); + this.bBytes = BaseTheta.fromBytes32(arg2); + + arg1Hi = Bytes16.wrap(arg1.slice(0, 16)); + arg1Lo = Bytes16.wrap(arg1.slice(16)); + arg2Hi = Bytes16.wrap(arg2.slice(0, 16)); + arg2Lo = Bytes16.wrap(arg2.slice(16)); + + this.res = getRes(opCode, arg1, arg2); // TODO can we get this from the EVM + + final BigInteger arg1BigInt = UInt256.fromBytes(arg1).toUnsignedBigInteger(); + final BigInteger arg2BigInt = UInt256.fromBytes(arg2).toUnsignedBigInteger(); + + this.tinyBase = isTiny(arg1BigInt); + this.tinyExponent = isTiny(arg2BigInt); + + // initialize bits + for (int i = 0; i < bits.length; i++) { + bits[i] = false; + } + + final Regime regime = getRegime(); + switch (regime) { + case TRIVIAL_MUL -> {} + case NON_TRIVIAL_MUL -> cBytes = BaseTheta.fromBytes32(res.getBytes32()); + case EXPONENT_ZERO_RESULT -> setArraysForZeroResultCase(); + case EXPONENT_NON_ZERO_RESULT -> { + this.exponentBits = arg2.toBigInteger().toString(); + snm = false; + } + case IOTA -> throw new RuntimeException("alu/mul regime was never set"); + } + } + + private static BaseBytes getRes(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { + + return switch (opCode) { + case MUL -> BaseBytes.fromBytes32(UInt256.fromBytes(arg1).multiply(UInt256.fromBytes(arg2))); + case EXP -> BaseBytes.fromBytes32(UInt256.fromBytes(arg1).pow(UInt256.fromBytes(arg2))); + default -> throw new RuntimeException("MUL module was given wrong opcode"); + }; + } + + private void setArraysForZeroResultCase() { + int nu = twoAdicity(arg1); + + if (nu >= 128) { + return; + } + + Bytes ones = Bytes.repeat((byte) 1, 8); + Bytes bytes; + + if (128 > nu && nu >= 64) { + bytes = aBytes.get(1); + } else { + for (int i = 0; i < 8; i++) { + cBytes.setChunk(0, ones); + } + bytes = aBytes.get(0); + } + int nuQuo = (nu / 8) % 8; + int nuRem = nu % 8; + byte pivotByte = bytes.get(7 - nuQuo); + + for (int i = 0; i < 8; i++) { + cBytes.set(1, i, pivotByte); + cBytes.set(2, i, boolToByte(i > 7 - nuRem)); + cBytes.set(3, i, boolToByte(i > 7 - nuQuo)); + hBytes.set(2, i, callFunc(i, 7 - nuRem)); + hBytes.set(3, i, callFunc(i, 7 - nuQuo)); + } + + bits = byteBits(UnsignedByte.of(pivotByte)); + + int lowerBoundOnTwoAdicity = 8 * (int) (hBytes.get(3, 7)) + (int) (hBytes.get(2, 7)); + + if (nu >= 64) { + lowerBoundOnTwoAdicity += 64; + } + + // our lower bound should coincide with the 2-adicity + if (lowerBoundOnTwoAdicity != nu) { + String s = + String.format( + "2-adicity nu = %d != %d = lower bound on 2-adicity", nu, lowerBoundOnTwoAdicity); + throw new RuntimeException(s); + } + if (lowerBoundOnTwoAdicity == 0) { + throw new RuntimeException("lower bound on 2 adicity == 0 in the zero result case"); + } + + final UInt256 twoFiftySix = UInt256.valueOf(256); + if (arg2.compareTo(twoFiftySix) >= 0) { + // arg2 = exponent >= 256 + hBytes.set(1, 6, (byte) ((lowerBoundOnTwoAdicity - 1) / 256)); + hBytes.set(1, 7, (byte) ((lowerBoundOnTwoAdicity - 1) % 256)); + } else { + // exponent < 256 + int exponent = arg2.toUnsignedBigInteger().intValue(); + int target = exponent * lowerBoundOnTwoAdicity - 256; + + if (target < 0) { + throw new RuntimeException("lower bound on 2-adicity is wrong"); + } + + if (target > 255 * (8 * 7 + 7 + 64)) { + throw new RuntimeException("something went awfully wrong ..."); + } + + final BaseTheta thing = BaseTheta.fromBytes32(UInt256.valueOf(target)); + hBytes.setChunk(1, thing.get(0)); + } + + return; + } + + public static byte callFunc(final int x, final int k) { + if (x < k) { + return 0; + } + return (byte) (x - k); + } + + public boolean exponentBit() { + return '1' == exponentBits.charAt(index); + } + + public boolean exponentSource() { + return this.index + 128 >= exponentBits.length(); + } + + public static int twoAdicity(final Bytes32 x) { + + if (x.isZero()) { + // panic("twoAdicity was called on zero") + return 256; + } + + String baseStringBase2 = x.toBigInteger().toString(2); + + for (int i = 0; i < baseStringBase2.length(); i++) { + int j = baseStringBase2.length() - i - 1; + char zeroAscii = '0'; + if (baseStringBase2.charAt(j) != zeroAscii) { + return i; + } + } + + return 0; + } + // + // private boolean largeExponent() { + // return exponentBits.length() > 128; + // } + + public enum Regime { + IOTA, + TRIVIAL_MUL, + NON_TRIVIAL_MUL, + EXPONENT_ZERO_RESULT, + EXPONENT_NON_ZERO_RESULT + } + + public boolean isOneLineInstruction() { + return tinyBase || tinyExponent; + } + + public Regime getRegime() { + + if (isOneLineInstruction()) return Regime.TRIVIAL_MUL; + + if (OpCode.MUL.equals(opCode)) { + return Regime.NON_TRIVIAL_MUL; + } + + if (OpCode.EXP.equals(opCode)) { + if (res.isZero()) { + return Regime.EXPONENT_ZERO_RESULT; + } else { + return Regime.EXPONENT_NON_ZERO_RESULT; + } + } + return Regime.IOTA; + } + + public static boolean isTiny(BigInteger arg) { + return arg.compareTo(BigInteger.valueOf(1)) <= 0; + } + + public boolean carryOn() { + + // first round is special + if (index == 0 && !snm) { + snm = true; + resAcc = UInt256.valueOf(1); // TODO assuming this is what SetOne() does + cBytes = BaseTheta.fromBytes32(arg1); + return true; + } + + if (snm == exponentBit()) { + hiToLoExponentBitAccumulatorReset(); + index++; + snm = false; + if (index == exponentBits.length()) { + return false; + } + } else { + snm = true; + } + return true; + } + + public int getBitNum() { + return bitNum(index, exponentBits.length()); + } + + private int bitNum(int i, int length) { + if (length <= 128) { + return i; + } else { + if (i + 128 < length) { + return i; + } else { + return i + 128 - length; + } + } + } + + public void update() { + + final BigInteger arg1BigInt = UInt256.fromBytes(arg1).toUnsignedBigInteger(); + if (!snm) { + // squaring + setHsAndBits(resAcc, resAcc); + expAcc = expAcc.add(expAcc); + resAcc = resAcc.multiply(resAcc); + } else { + // multiplying by base + setHsAndBits(UInt256.valueOf(arg1BigInt), resAcc); + expAcc = expAcc.add(UInt256.ONE); + resAcc = UInt256.valueOf(arg1BigInt).multiply(resAcc); + } + cBytes = BaseTheta.fromBytes32(resAcc); + } + + public void setHsAndBits(UInt256 a, UInt256 b) { + + setHsAndBitsFromBaseThetas(BaseTheta.fromBytes32(a), BaseTheta.fromBytes32(b)); + } + + public void setHsAndBitsFromBaseThetas(BaseTheta aBaseTheta, BaseTheta bBaseTheta) { + BaseTheta sumBaseTheta; + UInt256[] aBaseThetaInts = (UInt256[]) Array.newInstance(UInt256.class, 4); + UInt256[] bBaseThetaInts = (UInt256[]) Array.newInstance(UInt256.class, 4); + + for (int i = 0; i < 4; i++) { + aBaseThetaInts[i] = UInt256.fromBytes(aBaseTheta.get(i)); + bBaseThetaInts[i] = UInt256.fromBytes(bBaseTheta.get(i)); + } + + UInt256 sum, prod; + prod = aBaseThetaInts[1].multiply(bBaseThetaInts[0]); + sum = prod; // sum := a1 * b0 + prod = aBaseThetaInts[0].multiply(bBaseThetaInts[1]); + sum = sum.add(prod); // sum += a0 * b1 + + sumBaseTheta = BaseTheta.fromBytes32(sum); + hBytes.setChunk(0, sumBaseTheta.get(0)); + hBytes.setChunk(1, sumBaseTheta.get(1)); + long alpha = getOverflow(sum, 1, "alpha OOB"); + + sum = aBaseThetaInts[3].multiply(bBaseThetaInts[0]); // sum := a3 * b0 + prod = aBaseThetaInts[2].multiply(bBaseThetaInts[1]); + sum = sum.add(prod); // sum += a2 * b1 + prod = aBaseThetaInts[1].multiply(bBaseThetaInts[2]); + sum = sum.add(prod); // sum += a1 * b2 + prod = aBaseThetaInts[0].multiply(bBaseThetaInts[3]); + sum = sum.add(prod); // sum += a0 * b3 + + sumBaseTheta = BaseTheta.fromBytes32(sum); + hBytes.setChunk(2, sumBaseTheta.get(0)); + hBytes.setChunk(3, sumBaseTheta.get(1)); + long beta = getOverflow(sum, 3, "beta OOB"); + + prod = aBaseThetaInts[0].multiply(bBaseThetaInts[0]); + sum = prod; // sum := a0 * b0 + + prod = UInt256.fromBytes(hBytes.get(0)).shiftLeft(64); + sum = sum.add(prod); // sum += (h0 << 64) + + long eta = getOverflow(sum, 1, "eta OOB"); + + sum = UInt256.valueOf(eta); // sum := eta + sum = sum.add(UInt256.fromBytes(hBytes.get(1))); // sum += h1 + prod = UInt256.valueOf(alpha).shiftLeft(64); + sum = sum.add(prod); // sum += (alpha << 64) + prod = aBaseThetaInts[2].multiply(bBaseThetaInts[0]); + sum = sum.add(prod); // sum += a2 * b0 + prod = aBaseThetaInts[1].multiply(bBaseThetaInts[1]); + sum = sum.add(prod); // sum += a1 * b1 + prod = aBaseThetaInts[0].multiply(bBaseThetaInts[2]); + sum = sum.add(prod); // sum += a0 * b2 + sum = sum.add(UInt256.fromBytes(hBytes.get(2)).shiftLeft(64)); // sum += (h2 << 64) + + long mu = getOverflow(sum, 3, "mu OOB"); + + bits[0] = false; + bits[1] = false; + bits[2] = getBit(alpha, 0); + bits[3] = getBit(beta, 0); + bits[4] = getBit(beta, 1); + bits[5] = getBit(eta, 0); + bits[6] = getBit(mu, 0); + bits[7] = getBit(mu, 1); + return; + } + + // hiToLoExponentBitAccumulatorReset resets the exponent bit accumulator + // under the following conditions: + // - we are dealing with the high part of the exponent bits, i.e. md.exponentBit() = 0 + // - SQUARE_AND_MULTIPLY == EXPONENT_BIT + // - the exponent bit accumulator coincides with the high part of the exponent + private void hiToLoExponentBitAccumulatorReset() { + if (!exponentSource()) { + if (snm == exponentBit()) { // note: when called this is already assumed + Bytes32 arg2Copy = arg2.copy(); + if (arg2Copy.shiftRight(128).equals(expAcc)) { + expAcc = UInt256.MIN_VALUE; + } + } + } + } + + public int maxCt() { + return isOneLineInstruction() ? 1 : MMEDIUM; + } +} diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java new file mode 100644 index 0000000000..18b7cb1cc7 --- /dev/null +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java @@ -0,0 +1,520 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package net.consensys.linea.zktracer.module.alu.mul; + +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import net.consensys.linea.zktracer.bytes.UnsignedByte; + +@JsonPropertyOrder({"Trace", "Stamp"}) +@SuppressWarnings("unused") +public record MulTrace(@JsonProperty("Trace") Trace trace, @JsonProperty("Stamp") int stamp) { + @JsonPropertyOrder({ + "ACC_A_0", + "ACC_A_1", + "ACC_A_2", + "ACC_A_3", + "ACC_B_0", + "ACC_B_1", + "ACC_B_2", + "ACC_B_3", + "ACC_C_0", + "ACC_C_1", + "ACC_C_2", + "ACC_C_3", + "ACC_H_0", + "ACC_H_1", + "ACC_H_2", + "ACC_H_3", + "ARG_1_HI", + "ARG_1_LO", + "ARG_2_HI", + "ARG_2_LO", + "BITS", + "BIT_NUM", + "BYTE_A_0", + "BYTE_A_1", + "BYTE_A_2", + "BYTE_A_3", + "BYTE_B_0", + "BYTE_B_1", + "BYTE_B_2", + "BYTE_B_3", + "BYTE_C_0", + "BYTE_C_1", + "BYTE_C_2", + "BYTE_C_3", + "BYTE_H_0", + "BYTE_H_1", + "BYTE_H_2", + "BYTE_H_3", + "COUNTER", + "EXPONENT_BIT", + "EXPONENT_BIT_ACCUMULATOR", + "EXPONENT_BIT_SOURCE", + "INST", // INSTRUCTION + "MUL_STAMP", + "OLI", // "ONE_LINE_INSTRUCTION", + "RESULT_VANISHES", + "RES_HI", + "RES_LO", + "SQUARE_AND_MULTIPLY", + "TINY_BASE", + "TINY_EXPONENT", + }) + @SuppressWarnings("unused") + public record Trace( + @JsonProperty("ACC_A_0") List ACC_A_0, + @JsonProperty("ACC_A_1") List ACC_A_1, + @JsonProperty("ACC_A_2") List ACC_A_2, + @JsonProperty("ACC_A_3") List ACC_A_3, + @JsonProperty("ACC_B_0") List ACC_B_0, + @JsonProperty("ACC_B_1") List ACC_B_1, + @JsonProperty("ACC_B_2") List ACC_B_2, + @JsonProperty("ACC_B_3") List ACC_B_3, + @JsonProperty("ACC_C_0") List ACC_C_0, + @JsonProperty("ACC_C_1") List ACC_C_1, + @JsonProperty("ACC_C_2") List ACC_C_2, + @JsonProperty("ACC_C_3") List ACC_C_3, + @JsonProperty("ACC_H_0") List ACC_H_0, + @JsonProperty("ACC_H_1") List ACC_H_1, + @JsonProperty("ACC_H_2") List ACC_H_2, + @JsonProperty("ACC_H_3") List ACC_H_3, + @JsonProperty("ARG_1_HI") List ARG_1_HI, + @JsonProperty("ARG_1_LO") List ARG_1_LO, + @JsonProperty("ARG_2_HI") List ARG_2_HI, + @JsonProperty("ARG_2_LO") List ARG_2_LO, + @JsonProperty("BITS") List BITS, + @JsonProperty("BIT_NUM") List BIT_NUM, + @JsonProperty("BYTE_A_0") List BYTE_A_0, + @JsonProperty("BYTE_A_1") List BYTE_A_1, + @JsonProperty("BYTE_A_2") List BYTE_A_2, + @JsonProperty("BYTE_A_3") List BYTE_A_3, + @JsonProperty("BYTE_B_0") List BYTE_B_0, + @JsonProperty("BYTE_B_1") List BYTE_B_1, + @JsonProperty("BYTE_B_2") List BYTE_B_2, + @JsonProperty("BYTE_B_3") List BYTE_B_3, + @JsonProperty("BYTE_C_0") List BYTE_C_0, + @JsonProperty("BYTE_C_1") List BYTE_C_1, + @JsonProperty("BYTE_C_2") List BYTE_C_2, + @JsonProperty("BYTE_C_3") List BYTE_C_3, + @JsonProperty("BYTE_H_0") List BYTE_H_0, + @JsonProperty("BYTE_H_1") List BYTE_H_1, + @JsonProperty("BYTE_H_2") List BYTE_H_2, + @JsonProperty("BYTE_H_3") List BYTE_H_3, + @JsonProperty("COUNTER") List COUNTER, + @JsonProperty("EXPONENT_BIT") List EXPONENT_BIT, + @JsonProperty("EXPONENT_BIT_ACCUMULATOR") List EXPONENT_BIT_ACCUMULATOR, + @JsonProperty("EXPONENT_BIT_SOURCE") List EXPONENT_BIT_SOURCE, + @JsonProperty("INST") List INST, + @JsonProperty("MUL_STAMP") List MUL_STAMP, + @JsonProperty("OLI") List ONE_LINE_INSTRUCTION, + @JsonProperty("RESULT_VANISHES") List RESULT_VANISHES, + @JsonProperty("RES_HI") List RES_HI, + @JsonProperty("RES_LO") List RES_LO, + @JsonProperty("SQUARE_AND_MULTIPLY") List SQUARE_AND_MULTIPLY, + @JsonProperty("TINY_BASE") List TINY_BASE, + @JsonProperty("TINY_EXPONENT") List TINY_EXPONENT) { + + public static class Builder { + private final List mulStamp = new ArrayList<>(); + private final List counter = new ArrayList<>(); + private final List oneLineInstruction = new ArrayList<>(); + private final List tinyBase = new ArrayList<>(); + private final List tinyExponent = new ArrayList<>(); + private final List resultVanishes = new ArrayList<>(); + private final List inst = new ArrayList<>(); + private final List arg1Hi = new ArrayList<>(); + private final List arg1Lo = new ArrayList<>(); + private final List arg2Hi = new ArrayList<>(); + private final List arg2Lo = new ArrayList<>(); + private final List resHi = new ArrayList<>(); + private final List resLo = new ArrayList<>(); + private final List bits = new ArrayList<>(); + + private final List byteA3 = new ArrayList<>(); + private final List byteA2 = new ArrayList<>(); + private final List byteA1 = new ArrayList<>(); + private final List byteA0 = new ArrayList<>(); + private final List accA3 = new ArrayList<>(); + private final List accA2 = new ArrayList<>(); + private final List accA1 = new ArrayList<>(); + private final List accA0 = new ArrayList<>(); + + private final List byteB3 = new ArrayList<>(); + private final List byteB2 = new ArrayList<>(); + private final List byteB1 = new ArrayList<>(); + private final List byteB0 = new ArrayList<>(); + private final List accB3 = new ArrayList<>(); + private final List accB2 = new ArrayList<>(); + private final List accB1 = new ArrayList<>(); + private final List accB0 = new ArrayList<>(); + + private final List byteC3 = new ArrayList<>(); + private final List byteC2 = new ArrayList<>(); + private final List byteC1 = new ArrayList<>(); + private final List byteC0 = new ArrayList<>(); + private final List accC3 = new ArrayList<>(); + private final List accC2 = new ArrayList<>(); + private final List accC1 = new ArrayList<>(); + private final List accC0 = new ArrayList<>(); + + private final List byteH3 = new ArrayList<>(); + private final List byteH2 = new ArrayList<>(); + private final List byteH1 = new ArrayList<>(); + private final List byteH0 = new ArrayList<>(); + private final List accH3 = new ArrayList<>(); + private final List accH2 = new ArrayList<>(); + private final List accH1 = new ArrayList<>(); + private final List accH0 = new ArrayList<>(); + + private final List exponentBit = new ArrayList<>(); + private final List exponentBitAccumulator = new ArrayList<>(); + private final List exponentBitSource = new ArrayList<>(); + private final List squareAndMultiply = new ArrayList<>(); + private final List bitNum = new ArrayList<>(); + private int stamp = 0; + + private Builder() {} + + public static Builder newInstance() { + return new Builder(); + } + + public Builder appendAccA0(final BigInteger b) { + accA0.add(b); + return this; + } + + public Builder appendAccA1(final BigInteger b) { + accA1.add(b); + return this; + } + + public Builder appendAccA2(final BigInteger b) { + accA2.add(b); + return this; + } + + public Builder appendAccA3(final BigInteger b) { + accA3.add(b); + return this; + } + + public Builder appendAccB0(final BigInteger b) { + accB0.add(b); + return this; + } + + public Builder appendAccB1(final BigInteger b) { + accB1.add(b); + return this; + } + + public Builder appendAccB2(final BigInteger b) { + accB2.add(b); + return this; + } + + public Builder appendAccB3(final BigInteger b) { + accB3.add(b); + return this; + } + + public Builder appendAccC0(final BigInteger b) { + accC0.add(b); + return this; + } + + public Builder appendAccC1(final BigInteger b) { + accC1.add(b); + return this; + } + + public Builder appendAccC2(final BigInteger b) { + accC2.add(b); + return this; + } + + public Builder appendAccC3(final BigInteger b) { + accC3.add(b); + return this; + } + + public Builder appendAccH0(final BigInteger b) { + accH0.add(b); + return this; + } + + public Builder appendAccH1(final BigInteger b) { + accH1.add(b); + return this; + } + + public Builder appendAccH2(final BigInteger b) { + accH2.add(b); + return this; + } + + public Builder appendAccH3(final BigInteger b) { + accH3.add(b); + return this; + } + + public Builder appendArg1Hi(final BigInteger b) { + arg1Hi.add(b); + return this; + } + + public Builder appendArg1Lo(final BigInteger b) { + arg1Lo.add(b); + return this; + } + + public Builder appendArg2Hi(final BigInteger b) { + arg2Hi.add(b); + return this; + } + + public Builder appendArg2Lo(final BigInteger b) { + arg2Lo.add(b); + return this; + } + + public Builder appendBits(final Boolean b) { + bits.add(b); + return this; + } + + public Builder appendByteA0(final UnsignedByte b) { + byteA0.add(b); + return this; + } + + public Builder appendByteA1(final UnsignedByte b) { + byteA1.add(b); + return this; + } + + public Builder appendByteA2(final UnsignedByte b) { + byteA2.add(b); + return this; + } + + public Builder appendByteA3(final UnsignedByte b) { + byteA3.add(b); + return this; + } + + public Builder appendByteB0(final UnsignedByte b) { + byteB0.add(b); + return this; + } + + public Builder appendByteB1(final UnsignedByte b) { + byteB1.add(b); + return this; + } + + public Builder appendByteB2(final UnsignedByte b) { + byteB2.add(b); + return this; + } + + public Builder appendByteB3(final UnsignedByte b) { + byteB3.add(b); + return this; + } + + public Builder appendByteC0(final UnsignedByte b) { + byteC0.add(b); + return this; + } + + public Builder appendByteC1(final UnsignedByte b) { + byteC1.add(b); + return this; + } + + public Builder appendByteC2(final UnsignedByte b) { + byteC2.add(b); + return this; + } + + public Builder appendByteC3(final UnsignedByte b) { + byteC3.add(b); + return this; + } + + public Builder appendByteH0(final UnsignedByte b) { + byteH0.add(b); + return this; + } + + public Builder appendByteH1(final UnsignedByte b) { + byteH1.add(b); + return this; + } + + public Builder appendByteH2(final UnsignedByte b) { + byteH2.add(b); + return this; + } + + public Builder appendByteH3(final UnsignedByte b) { + byteH3.add(b); + return this; + } + + public Builder appendCounter(final Integer b) { + counter.add(b); + return this; + } + + public Builder appendInst(final UnsignedByte b) { + inst.add(b); + return this; + } + + public Builder appendOneLineInstruction(final Boolean b) { + oneLineInstruction.add(b); + return this; + } + + public Builder appendTinyBase(final Boolean b) { + tinyBase.add(b); + return this; + } + + public Builder appendTinyExponent(final Boolean b) { + tinyExponent.add(b); + return this; + } + + public Builder appendResultVanishes(final Boolean b) { + resultVanishes.add(b); + return this; + } + + public Builder appendResHi(final BigInteger b) { + resHi.add(b); + return this; + } + + public Builder appendResLo(final BigInteger b) { + resLo.add(b); + return this; + } + + // + + public Builder appendExponentBit(final Boolean b) { + exponentBit.add(b); + return this; + } + + public Builder appendExponentBitAcc(final BigInteger b) { + exponentBitAccumulator.add(b); + return this; + } + + public Builder appendExponentBitSource(final Boolean b) { + exponentBitSource.add(b); + return this; + } + + public Builder appendSquareAndMultiply(final Boolean b) { + squareAndMultiply.add(b); + return this; + } + + public Builder appendBitNum(final Integer b) { + bitNum.add(b); + return this; + } + + public Builder appendStamp(final Integer b) { + mulStamp.add(b); + return this; + } + + public Builder setStamp(final int stamp) { + this.stamp = stamp; + return this; + } + + public MulTrace build() { + return new MulTrace( + new Trace( + accA0, + accA1, + accA2, + accA3, + accB0, + accB1, + accB2, + accB3, + accC0, + accC1, + accC2, + accC3, + accH0, + accH1, + accH2, + accH3, + arg1Hi, + arg1Lo, + arg2Hi, + arg2Lo, + bits, + bitNum, + byteA0, + byteA1, + byteA2, + byteA3, + byteB0, + byteB1, + byteB2, + byteB3, + byteC0, + byteC1, + byteC2, + byteC3, + byteH0, + byteH1, + byteH2, + byteH3, + counter, + exponentBit, + exponentBitAccumulator, + exponentBitSource, + inst, + mulStamp, + oneLineInstruction, + resultVanishes, + resHi, + resLo, + squareAndMultiply, + tinyBase, + tinyExponent), + stamp); + } + } + } +} diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java new file mode 100644 index 0000000000..4ac38dc899 --- /dev/null +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java @@ -0,0 +1,161 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package net.consensys.linea.zktracer.module.alu.mul; + +import org.hyperledger.besu.evm.frame.MessageFrame; + +import java.util.List; + +import net.consensys.linea.zktracer.OpCode; +import net.consensys.linea.zktracer.bytes.UnsignedByte; +import net.consensys.linea.zktracer.module.ModuleTracer; +import org.apache.tuweni.bytes.Bytes32; +import org.apache.tuweni.units.bigints.UInt256; + +public class MulTracer implements ModuleTracer { + + private int stamp = 0; + + @Override + public String jsonKey() { + return "mul"; + } + + @Override + public List supportedOpCodes() { + return List.of(OpCode.MUL, OpCode.EXP); + } + + @SuppressWarnings("UnusedVariable") + @Override + public Object trace(MessageFrame frame) { + final Bytes32 arg1 = Bytes32.wrap(frame.getStackItem(0)); + final Bytes32 arg2 = Bytes32.wrap(frame.getStackItem(1)); + + final OpCode opCode = OpCode.of(frame.getCurrentOperation().getOpcode()); + + // argument order is reversed ?? + final MulData data = new MulData(opCode, arg2, arg1); + final MulTrace.Trace.Builder builder = MulTrace.Trace.Builder.newInstance(); + + switch (data.getRegime()) { + case EXPONENT_ZERO_RESULT -> { + trace(builder, data); + } + + case EXPONENT_NON_ZERO_RESULT -> { + if (data.carryOn()) { + data.update(); + trace(builder, data); + } + } + + case TRIVIAL_MUL, NON_TRIVIAL_MUL -> { + data.setHsAndBits(UInt256.fromBytes(arg1), UInt256.fromBytes(arg2)); + trace(builder, data); + } + + default -> throw new RuntimeException("regime not supported"); + } + // TODO captureBlockEnd should be called from elsewhere - not within messageFrame + // captureBlockEnd(); + MulData finalZeroToTheZero = new MulData(OpCode.EXP, Bytes32.ZERO, Bytes32.ZERO); + trace(builder, finalZeroToTheZero); + + return builder.build(); + } + + private void trace(final MulTrace.Trace.Builder builder, final MulData data) { + stamp++; + + for (int ct = 0; ct < data.maxCt(); ct++) { + trace(builder, data, ct); + } + } + + private void trace(final MulTrace.Trace.Builder builder, final MulData data, final int i) { + builder.appendStamp(stamp); + builder.appendCounter(i); + + builder + .appendOneLineInstruction(data.isOneLineInstruction()) + .appendTinyBase(data.tinyBase) + .appendTinyExponent(data.tinyExponent) + .appendResultVanishes(data.res.isZero()); + + builder + .appendInst(UnsignedByte.of(data.opCode.value)) + .appendArg1Hi(data.arg1Hi.toUnsignedBigInteger()) + .appendArg1Lo(data.arg1Lo.toUnsignedBigInteger()) + .appendArg2Hi(data.arg2Hi.toUnsignedBigInteger()) + .appendArg2Lo(data.arg2Lo.toUnsignedBigInteger()); + + builder + .appendResHi(data.res.getHigh().toUnsignedBigInteger()) + .appendResLo(data.res.getLow().toUnsignedBigInteger()); + + builder.appendBits(data.bits[i]); + + builder + .appendByteA3(UnsignedByte.of(data.aBytes.get(3, i))) + .appendByteA2(UnsignedByte.of(data.aBytes.get(2, i))) + .appendByteA1(UnsignedByte.of(data.aBytes.get(1, i))) + .appendByteA0(UnsignedByte.of(data.aBytes.get(0, i))); + builder + .appendAccA3(data.aBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) + .appendAccA2(data.aBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) + .appendAccA1(data.aBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) + .appendAccA0(data.aBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); + + builder + .appendByteB3(UnsignedByte.of(data.bBytes.get(3, i))) + .appendByteB2(UnsignedByte.of(data.bBytes.get(2, i))) + .appendByteB1(UnsignedByte.of(data.bBytes.get(1, i))) + .appendByteB0(UnsignedByte.of(data.bBytes.get(0, i))); + builder + .appendAccB3(data.bBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) + .appendAccB2(data.bBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) + .appendAccB1(data.bBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) + .appendAccB0(data.bBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); + builder + .appendByteC3(UnsignedByte.of(data.cBytes.get(3, i))) + .appendByteC2(UnsignedByte.of(data.cBytes.get(2, i))) + .appendByteC1(UnsignedByte.of(data.cBytes.get(1, i))) + .appendByteC0(UnsignedByte.of(data.cBytes.get(0, i))); + builder + .appendAccC3(data.cBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) + .appendAccC2(data.cBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) + .appendAccC1(data.cBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) + .appendAccC0(data.cBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); + + builder + .appendByteH3(UnsignedByte.of(data.hBytes.get(3, i))) + .appendByteH2(UnsignedByte.of(data.hBytes.get(2, i))) + .appendByteH1(UnsignedByte.of(data.hBytes.get(1, i))) + .appendByteH0(UnsignedByte.of(data.hBytes.get(0, i))); + builder + .appendAccH3(data.hBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) + .appendAccH2(data.hBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) + .appendAccH1(data.hBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) + .appendAccH0(data.hBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); + builder + .appendExponentBit(data.exponentBit()) + .appendExponentBitAcc(data.expAcc.toUnsignedBigInteger()) + .appendExponentBitSource(data.exponentSource()) + .appendSquareAndMultiply(data.snm) + .appendBitNum(data.getBitNum()); + builder.setStamp(stamp); + } +} diff --git a/src/test/java/net/consensys/linea/zktracer/AbstractModuleTracerCorsetTest.java b/src/test/java/net/consensys/linea/zktracer/AbstractModuleTracerCorsetTest.java index 3a1ae8bae0..913fd42d91 100644 --- a/src/test/java/net/consensys/linea/zktracer/AbstractModuleTracerCorsetTest.java +++ b/src/test/java/net/consensys/linea/zktracer/AbstractModuleTracerCorsetTest.java @@ -20,6 +20,7 @@ import java.util.stream.Stream; import org.apache.tuweni.bytes.Bytes32; +import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; import org.junit.jupiter.params.ParameterizedTest; @@ -27,6 +28,7 @@ import org.junit.jupiter.params.provider.MethodSource; @TestInstance(Lifecycle.PER_CLASS) +@Tag("CorsetTest") public abstract class AbstractModuleTracerCorsetTest extends AbstractBaseModuleTracerTest { static final Random rand = new Random(); private static final int TEST_REPETITIONS = 8; diff --git a/src/test/java/net/consensys/linea/zktracer/corset/module/wcp/WcpTracerTest.java b/src/test/java/net/consensys/linea/zktracer/corset/module/wcp/WcpTracerTest.java deleted file mode 100644 index 038bde857e..0000000000 --- a/src/test/java/net/consensys/linea/zktracer/corset/module/wcp/WcpTracerTest.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Copyright ConsenSys AG. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - * - * SPDX-License-Identifier: Apache-2.0 - */ -package net.consensys.linea.zktracer.corset.module.wcp; - -import static net.consensys.linea.zktracer.OpCode.SGT; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import static org.mockito.Mockito.when; - -import org.hyperledger.besu.evm.frame.MessageFrame; -import org.hyperledger.besu.evm.operation.Operation; - -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import java.util.stream.Stream; - -import net.consensys.linea.zktracer.OpCode; -import net.consensys.linea.zktracer.ZkTraceBuilder; -import net.consensys.linea.zktracer.ZkTracer; -import net.consensys.linea.zktracer.corset.CorsetValidator; -import net.consensys.linea.zktracer.module.wcp.WcpTracer; -import org.apache.tuweni.bytes.Bytes32; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; - -@ExtendWith(MockitoExtension.class) -class WcpTracerTest { - private ZkTracer zkTracer; - private ZkTraceBuilder zkTraceBuilder; - @Mock MessageFrame mockFrame; - @Mock Operation mockOperation; - - private static final Random rand = new Random(); - private static final int TEST_REPETITIONS = 4; - - @BeforeEach - void setUp() { - zkTraceBuilder = new ZkTraceBuilder(); - zkTracer = new ZkTracer(zkTraceBuilder, List.of(new WcpTracer())); - when(mockFrame.getCurrentOperation()).thenReturn(mockOperation); - } - - @ParameterizedTest() - @MethodSource("provideRandomArguments") - void testRandomWcp(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { - when(mockOperation.getOpcode()).thenReturn((int) opCode.value); - when(mockFrame.getStackItem(0)).thenReturn(arg1); - when(mockFrame.getStackItem(1)).thenReturn(arg2); - zkTracer.tracePreExecution(mockFrame); - assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); - } - - @Test - public void testNonRandomWcp() { - Bytes32 arg1 = - Bytes32.fromHexString("0xdcd5cf52e4daec5389587d0d0e996e6ce2d0546b63d3ea0a0dc48ad984d180a9"); - Bytes32 arg2 = - Bytes32.fromHexString("0x0479484af4a59464a48818b3980174687661bafb13d06f49537995fa6c02159e"); - traceOperation(SGT, arg1, arg2); - } - - public static Stream provideRandomArguments() { - final List arguments = new ArrayList<>(); - for (OpCode opCode : new WcpTracer().supportedOpCodes()) { - for (int i = 0; i <= TEST_REPETITIONS; i++) { - Bytes32[] payload = new Bytes32[2]; - payload[0] = Bytes32.random(rand); - payload[1] = Bytes32.random(rand); - arguments.add(Arguments.of(opCode, payload[0], payload[1])); - } - } - return arguments.stream(); - } - - private void traceOperation(OpCode opcode, Bytes32 arg1, Bytes32 arg2) { - when(mockOperation.getOpcode()).thenReturn((int) opcode.value); - when(mockFrame.getStackItem(0)).thenReturn(arg1); - when(mockFrame.getStackItem(1)).thenReturn(arg2); - zkTracer.tracePreExecution(mockFrame); - assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); - } -} diff --git a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AddTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AddTracerTest.java similarity index 96% rename from src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AddTracerTest.java rename to src/test/java/net/consensys/linea/zktracer/module/alu/add/AddTracerTest.java index dd4a384eb3..066b66eaf4 100644 --- a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AddTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AddTracerTest.java @@ -12,7 +12,7 @@ * * SPDX-License-Identifier: Apache-2.0 */ -package net.consensys.linea.zktracer.corset.module.alu.add; +package net.consensys.linea.zktracer.module.alu.add; import java.util.ArrayList; import java.util.List; @@ -22,7 +22,6 @@ import net.consensys.linea.zktracer.AbstractModuleTracerCorsetTest; import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.module.ModuleTracer; -import net.consensys.linea.zktracer.module.alu.add.AddTracer; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; diff --git a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AdderTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AdderTest.java similarity index 95% rename from src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AdderTest.java rename to src/test/java/net/consensys/linea/zktracer/module/alu/add/AdderTest.java index 3b28987e43..590c74bb96 100644 --- a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AdderTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AdderTest.java @@ -12,13 +12,12 @@ * * SPDX-License-Identifier: Apache-2.0 */ -package net.consensys.linea.zktracer.corset.module.alu.add; +package net.consensys.linea.zktracer.module.alu.add; import static org.assertj.core.api.Assertions.assertThat; import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.bytestheta.BaseBytes; -import net.consensys.linea.zktracer.module.alu.add.Adder; import org.apache.tuweni.bytes.Bytes32; import org.junit.jupiter.api.Test; diff --git a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/mod/ModTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mod/ModTracerTest.java similarity index 97% rename from src/test/java/net/consensys/linea/zktracer/corset/module/alu/mod/ModTracerTest.java rename to src/test/java/net/consensys/linea/zktracer/module/alu/mod/ModTracerTest.java index 4c1a84a0f7..b046a13bf5 100644 --- a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/mod/ModTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mod/ModTracerTest.java @@ -12,7 +12,7 @@ * * SPDX-License-Identifier: Apache-2.0 */ -package net.consensys.linea.zktracer.corset.module.alu.mod; +package net.consensys.linea.zktracer.module.alu.mod; import java.math.BigInteger; import java.util.ArrayList; @@ -23,7 +23,6 @@ import net.consensys.linea.zktracer.AbstractModuleTracerCorsetTest; import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.module.ModuleTracer; -import net.consensys.linea.zktracer.module.alu.mod.ModTracer; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java new file mode 100644 index 0000000000..2a92031d83 --- /dev/null +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java @@ -0,0 +1,165 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package net.consensys.linea.zktracer.module.alu.mul; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.stream.Stream; + +import net.consensys.linea.zktracer.AbstractModuleTracerCorsetTest; +import net.consensys.linea.zktracer.OpCode; +import net.consensys.linea.zktracer.module.ModuleTracer; +import org.apache.tuweni.bytes.Bytes; +import org.apache.tuweni.bytes.Bytes32; +import org.apache.tuweni.units.bigints.UInt256; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class MulTracerTest extends AbstractModuleTracerCorsetTest { + private static final Random rand = new Random(); + + private static final int TEST_MUL_REPETITIONS = 16; + + @ParameterizedTest() + @MethodSource("provideRandomAluMulArguments") + void aluMulTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, List.of(arg1, arg2)); + } + + @ParameterizedTest() + @MethodSource("singleTinyExponentiation") + void testSingleTinyExponentiation(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, List.of(arg1, arg2)); + } + + @ParameterizedTest() + @MethodSource("provideTinyArguments") + void tinyArgsTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, List.of(arg1, arg2)); + } + + @ParameterizedTest() + @MethodSource("provideSpecificNonTinyArguments") + void nonTinyArgsTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, List.of(arg1, arg2)); + } + + @ParameterizedTest() + @MethodSource("provideRandomNonTinyArguments") + void randomNonTinyArgsTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, List.of(arg1, arg2)); + } + + @ParameterizedTest() + @MethodSource("multiplyByZero") + void zerosArgsTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, List.of(arg1, arg2)); + } + + public Stream singleTinyExponentiation() { + List arguments = new ArrayList<>(); + + Bytes32 bytes1 = Bytes32.leftPad(Bytes.fromHexString("0x13")); + Bytes32 bytes2 = Bytes32.leftPad(Bytes.fromHexString("0x02")); + arguments.add(Arguments.of(OpCode.EXP, bytes1, bytes2)); + return arguments.stream(); + } + + public Stream provideRandomAluMulArguments() { + List arguments = new ArrayList<>(); + + for (int i = 0; i < TEST_MUL_REPETITIONS; i++) { + arguments.add(getRandomAluMulInstruction(rand.nextInt(32) + 1, rand.nextInt(32) + 1)); + } + return arguments.stream(); + } + + private Arguments getRandomAluMulInstruction(int sizeArg1MinusOne, int sizeArg2MinusOne) { + Bytes32 bytes1 = UInt256.valueOf(sizeArg1MinusOne); + Bytes32 bytes2 = UInt256.valueOf(sizeArg2MinusOne); + OpCode opCode = getRandomSupportedOpcode(); + return Arguments.of(opCode, bytes1, bytes2); + } + + public Stream provideSpecificNonTinyArguments() { + // these values are used in Go module test + // 0x8a, 0x48, 0xaa, 0x20, 0xe2, 0x00, 0xce, 0x3f, 0xee, 0x16, 0xb5, 0xdc, 0xde, 0xc5, 0xc4, + // 0xfa, + // 0xff, 0x61, 0x3b, 0xc9, 0x14, 0xd4, 0x7c, 0xd6, 0xca, 0x69, 0x55, 0x3f, 0x8e, + // 0xb2, 0xb3, 0x77, + // byte(vm.PUSH32), + // 0x59, 0xb6, 0x35, 0xfe, 0xc8, 0x94, 0xca, 0xa3, 0xed, 0x68, 0x17, 0xb1, 0xe6, + // 0x7b, 0x3c, 0xba, + // 0xeb, 0x87, 0x57, 0xfd, 0x6c, 0x7b, 0x03, 0x11, 0x9b, 0x79, 0x53, 0x03, 0xb7, + // 0xcd, 0x72, 0xc1, + final Bytes32[] payload = new Bytes32[2]; + payload[0] = + Bytes32.fromHexString("0x8a48aa20e200ce3fee16b5dcdec5c4faff613bc914d47cd6ca69553f8eb2b377"); + payload[1] = + Bytes32.fromHexString("0x59b635fec894caa3ed6817b1e67b3cbaeb8757fd6c7b03119b795303b7cd72c1"); + return Stream.of(Arguments.of(OpCode.MUL, payload[0], payload[1])); + } + + public Stream provideRandomNonTinyArguments() { + List arguments = new ArrayList<>(); + + for (int i = 0; i < TEST_MUL_REPETITIONS; i++) { + arguments.add(getRandomAluMulInstruction(rand.nextInt(32) + 1, rand.nextInt(32) + 1)); + } + return arguments.stream(); + } + + public Stream provideTinyArguments() { + List arguments = new ArrayList<>(); + for (int i = 0; i < 4; i++) { + arguments.add(getRandomAluMulInstruction(i, i + 1)); + } + return arguments.stream(); + } + + @Override + protected Stream provideNonRandomArguments() { + List arguments = new ArrayList<>(); + for (OpCode opCode : getModuleTracer().supportedOpCodes()) { + for (int k = 0; k <= 3; k++) { + for (int i = 0; i <= 3; i++) { + arguments.add(Arguments.of(opCode, List.of(UInt256.valueOf(i), UInt256.valueOf(k)))); + } + } + } + return arguments.stream(); + } + + protected Stream multiplyByZero() { + List arguments = new ArrayList<>(); + for (int i = 0; i < 2; i++) { + Bytes32 bytes1 = Bytes32.ZERO; + Bytes32 bytes2 = UInt256.valueOf(i); + arguments.add(Arguments.of(OpCode.MUL, bytes1, bytes2)); + arguments.add(Arguments.of(OpCode.MUL, bytes2, bytes1)); + } + return arguments.stream(); + } + + @Override + protected ModuleTracer getModuleTracer() { + return new MulTracer(); + } +} diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java new file mode 100644 index 0000000000..2671110630 --- /dev/null +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java @@ -0,0 +1,239 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package net.consensys.linea.zktracer.module.alu.mul; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import java.math.BigInteger; + +import net.consensys.linea.zktracer.OpCode; +import net.consensys.linea.zktracer.bytes.Bytes16; +import net.consensys.linea.zktracer.bytes.UnsignedByte; +import net.consensys.linea.zktracer.bytestheta.BaseTheta; +import net.consensys.linea.zktracer.module.Util; +import org.apache.tuweni.bytes.Bytes; +import org.apache.tuweni.bytes.Bytes32; +import org.apache.tuweni.units.bigints.UInt256; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +public class MulUtilsTest { + @Test + public void isTiny() { + // tiny means zero or one + assertThat(MulData.isTiny(BigInteger.ZERO)).isTrue(); + assertThat(MulData.isTiny(BigInteger.ONE)).isTrue(); + assertThat(MulData.isTiny(BigInteger.TWO)).isFalse(); + assertThat(MulData.isTiny(BigInteger.TEN)).isFalse(); + } + + @Test + public void twoAdicity() { + assertThat(MulData.twoAdicity(UInt256.MIN_VALUE)).isEqualTo(256); + // TODO no idea what these should be + // assertThat(MulData.twoAdicity(UInt256.MAX_VALUE)).isEqualTo(0); + // assertThat(MulData.twoAdicity(UInt256.valueOf(1))).isEqualTo(0); + } + + @Test + public void multiplyByZero() { + Bytes32 arg1 = Bytes32.random(); + OpCode mul = OpCode.MUL; + MulData oxo = new MulData(mul, arg1, Bytes32.ZERO); + Assertions.assertThat(oxo.arg2Hi.isZero()).isTrue(); + Assertions.assertThat(oxo.arg2Lo).isEqualTo(Bytes16.ZERO); + Assertions.assertThat(oxo.arg2Hi).isEqualTo(Bytes16.ZERO); + assertThat(oxo.opCode).isEqualTo(mul); + assertThat(oxo.tinyExponent).isTrue(); + assertThat(oxo.isOneLineInstruction()).isTrue(); + assertThat(oxo.bits[0]).isFalse(); + } + + @Test + public void zeroExp() { + Bytes32 arg1 = Bytes32.random(); + OpCode mul = OpCode.EXP; + MulData oxo = new MulData(mul, arg1, Bytes32.ZERO); + Assertions.assertThat(oxo.arg2Hi.isZero()).isTrue(); + Assertions.assertThat(oxo.arg2Lo).isEqualTo(Bytes16.ZERO); + Assertions.assertThat(oxo.arg2Hi).isEqualTo(Bytes16.ZERO); + assertThat(oxo.opCode).isEqualTo(mul); + assertThat(oxo.tinyExponent).isTrue(); + assertThat(oxo.isOneLineInstruction()).isTrue(); + assertThat(oxo.bits[0]).isFalse(); + } + + @Test + public void testByteBits_ofZero() { + Boolean[] booleans = Util.byteBits(UnsignedByte.of(0)); + assertThat(booleans.length).isEqualTo(8); + assertThat(booleans[0]).isNotNull(); + } + + @Test + public void hBytesAllZeros() { + Bytes32 arg1 = Bytes32.ZERO; + Bytes32 arg2 = Bytes32.ZERO; + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas( + BaseTheta.fromBytes32(UInt256.ZERO), BaseTheta.fromBytes32(UInt256.ZERO)); + assertThat(mulData.hBytes.get(0).isZero()).isTrue(); + assertThat(mulData.hBytes.get(1).isZero()).isTrue(); + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + + assertThat(mulData.hBytes.get(0).shiftLeft(64)).isEqualTo(mulData.hBytes.get(1)); // ZERO + } + + @Test + public void hBytesWhereOneArgIsZero() { + Bytes32 arg1 = Bytes32.ZERO; + Bytes32 arg2 = Bytes32.ZERO; + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas( + BaseTheta.fromBytes32(UInt256.ZERO), BaseTheta.fromBytes32(UInt256.valueOf(1))); + assertThat(mulData.hBytes.get(0).isZero()).isTrue(); + assertThat(mulData.hBytes.get(1).isZero()).isTrue(); + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + + assertThat(mulData.hBytes.get(0).shiftLeft(64)).isEqualTo(mulData.hBytes.get(1)); // ZERO + } + + @Test + public void hBytes_5_5() { + Bytes32 arg1 = Bytes32.fromHexString("0x05"); + Bytes32 arg2 = Bytes32.fromHexString("0x05"); + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas(BaseTheta.fromBytes32(arg1), BaseTheta.fromBytes32(arg2)); + assertThat(mulData.hBytes.get(0).isZero()).isTrue(); + assertThat(mulData.hBytes.get(1).isZero()).isTrue(); + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + } + + @Test + public void hBytes_largeEnoughArgsToGetNonZeros() { + + // these args aren't used directly in the calculation of hs and bits + Bytes32 arg1 = Bytes32.fromHexString("0x05"); + Bytes32 arg2 = Bytes32.fromHexString("0x05"); + + BaseTheta aBaseTheta = BaseTheta.fromBytes32(UInt256.valueOf(43532)); // aa 0c // 170 12 + + // squaring is one way to get a sufficiently big number + BigInteger b = BigInteger.valueOf(82494664664768L); + BigInteger b2 = b.multiply(b); // 6805369698152522037820493824 + UInt256 b2uint = UInt256.valueOf(b2); + + BaseTheta bBaseTheta = BaseTheta.fromBytes32(b2uint); + + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas(aBaseTheta, bBaseTheta); + + Bytes h0 = Bytes.fromHexString("0x00000e9b37bfd908"); + + assertThat(mulData.hBytes.get(0)).isEqualTo(h0); + assertThat(mulData.hBytes.get(1).isZero()).isTrue(); + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + } + + @Test + public void hBytes_aReallyBigNumber() { + + // these args aren't used directly in the calculation of hs and bits + Bytes32 arg1 = Bytes32.fromHexString("0x05"); + Bytes32 arg2 = Bytes32.fromHexString("0x05"); + + BigInteger b = new BigInteger("296251353699975589350401737146368"); + UInt256 b4uint = UInt256.valueOf(b); + + BaseTheta aBaseTheta = BaseTheta.fromBytes32(b4uint); + BaseTheta bBaseTheta = BaseTheta.fromBytes32(b4uint); + + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas(aBaseTheta, bBaseTheta); + + Bytes h0 = Bytes.fromHexString("0xe9e4064d86460000"); + Bytes h1 = Bytes.fromHexString("0x00000779079ae9e3"); + + assertThat(mulData.hBytes.get(0)).isEqualTo(h0); + assertThat(mulData.hBytes.get(1)).isEqualTo(h1); + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + } + + @Test + public void hBytesValue_and_generatesATrueBit() { + + // these args aren't used directly in the calculation of hs and bits + Bytes32 arg1 = Bytes32.fromHexString("0x05"); + Bytes32 arg2 = Bytes32.fromHexString("0x05"); + + BigInteger b = new BigInteger("3672491014949214151879080813"); + UInt256 b4uint = UInt256.valueOf(b); + + BaseTheta aBaseTheta = BaseTheta.fromBytes32(b4uint); + BaseTheta bBaseTheta = BaseTheta.fromBytes32(b4uint); + + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas(aBaseTheta, bBaseTheta); + + Bytes h0 = Bytes.fromHexString("0x8d222c056a351fb0"); + Bytes h1 = Bytes.fromHexString("0x0000000013fc1cf0"); + + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + assertThat(mulData.hBytes.get(0)).isEqualTo(h0); + assertThat(mulData.hBytes.get(1)).isEqualTo(h1); + + // bits + // expected value obtained from go implementation debug output + Boolean[] expectedBools = {false, false, false, false, false, true, false, false}; + assertThat(mulData.bits).isEqualTo(expectedBools); + } + + @Test + public void hBytes_twoReallyBigNumbers_generatesADifferentTrueBit() { + + // these args aren't used directly in the calculation of hs and bits + Bytes32 arg1 = Bytes32.fromHexString("0x05"); + Bytes32 arg2 = Bytes32.fromHexString("0x05"); + + BigInteger b1 = + new BigInteger( + "22469423347992668196557015132986860313508181747976369840918221307635594854917"); + UInt256 b1uint = UInt256.valueOf(b1); + BigInteger b2 = + new BigInteger( + "24978870742348927442211038043709369780953629326985996012322553405261539030097"); + UInt256 b2uint = UInt256.valueOf(b2); + + BaseTheta aBaseTheta = BaseTheta.fromBytes32(b1uint); + BaseTheta bBaseTheta = BaseTheta.fromBytes32(b2uint); + + final MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas(aBaseTheta, bBaseTheta); + + BigInteger sum_010 = new BigInteger("375860551383434850958895718584879559103"); + assertThat(Util.getOverflow(UInt256.valueOf(sum_010), 3, "mu OOB")).isEqualTo(1); + + // bits + // expected value obtained from go implementation debug output + Boolean[] expectedBools = {false, false, false, false, false, false, true, false}; + assertThat(mulData.bits).isEqualTo(expectedBools); + } +} diff --git a/src/test/java/net/consensys/linea/zktracer/corset/module/shf/ShfTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/shf/ShfTracerTest.java similarity index 97% rename from src/test/java/net/consensys/linea/zktracer/corset/module/shf/ShfTracerTest.java rename to src/test/java/net/consensys/linea/zktracer/module/shf/ShfTracerTest.java index ac0bcd7bc5..5b24a645b0 100644 --- a/src/test/java/net/consensys/linea/zktracer/corset/module/shf/ShfTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/shf/ShfTracerTest.java @@ -12,7 +12,7 @@ * * SPDX-License-Identifier: Apache-2.0 */ -package net.consensys.linea.zktracer.corset.module.shf; +package net.consensys.linea.zktracer.module.shf; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.mockito.Mockito.when; @@ -28,11 +28,11 @@ import net.consensys.linea.zktracer.ZkTraceBuilder; import net.consensys.linea.zktracer.ZkTracer; import net.consensys.linea.zktracer.corset.CorsetValidator; -import net.consensys.linea.zktracer.module.shf.ShfTracer; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Named; +import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; @@ -44,6 +44,7 @@ import org.slf4j.LoggerFactory; @ExtendWith(MockitoExtension.class) +@Tag("CorsetTest") class ShfTracerTest { private static final Logger LOG = LoggerFactory.getLogger(ShfTracerTest.class); diff --git a/src/test/java/net/consensys/linea/zktracer/module/wcp/WcpTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/wcp/WcpTracerTest.java new file mode 100644 index 0000000000..3a3e78c1aa --- /dev/null +++ b/src/test/java/net/consensys/linea/zktracer/module/wcp/WcpTracerTest.java @@ -0,0 +1,45 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package net.consensys.linea.zktracer.module.wcp; + +import static net.consensys.linea.zktracer.OpCode.GT; + +import java.util.List; +import java.util.stream.Stream; + +import net.consensys.linea.zktracer.AbstractModuleTracerCorsetTest; +import net.consensys.linea.zktracer.module.ModuleTracer; +import org.apache.tuweni.bytes.Bytes32; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.provider.Arguments; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class WcpTracerTest extends AbstractModuleTracerCorsetTest { + + @Override + protected ModuleTracer getModuleTracer() { + return new WcpTracer(); + } + + @Override + protected Stream provideNonRandomArguments() { + Bytes32 arg1 = + Bytes32.fromHexString("0xdcd5cf52e4daec5389587d0d0e996e6ce2d0546b63d3ea0a0dc48ad984d180a9"); + Bytes32 arg2 = + Bytes32.fromHexString("0x0479484af4a59464a48818b3980174687661bafb13d06f49537995fa6c02159e"); + return Stream.of(Arguments.of(GT, List.of(arg1, arg2))); + } +}