From 9c31bac1818a9c0ea859cc22f6d73065ce349c90 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 29 Mar 2023 06:50:53 +1000 Subject: [PATCH 01/31] json object for mul Signed-off-by: Sally MacFarlane --- .../zktracer/module/alu/mul/MulTrace.java | 506 ++++++++++++++++++ 1 file changed, 506 insertions(+) create mode 100644 src/main/java/net/consensys/zktracer/module/alu/mul/MulTrace.java diff --git a/src/main/java/net/consensys/zktracer/module/alu/mul/MulTrace.java b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTrace.java new file mode 100644 index 0000000000..74b3469d1e --- /dev/null +++ b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTrace.java @@ -0,0 +1,506 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package net.consensys.zktracer.module.alu.mul; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import net.consensys.zktracer.bytes.UnsignedByte; + +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.List; + +@JsonPropertyOrder({"Trace", "Stamp"}) +@SuppressWarnings("unused") +public record MulTrace(@JsonProperty("Trace") Trace trace, @JsonProperty("Stamp") int stamp) { + @JsonPropertyOrder({ + "MUL_STAMP", + "COUNTER", + "OLI", // "ONE_LINE_INSTRUCTION", + "TINY_BASE", + "TINY_EXPONENT", + "RESULT_VANISHES", + "INST", // INSTRUCTION + "ARG_1_HI", + "ARG_1_LO", + "ARG_2_HI", + "ARG_2_LO", + "RES_HI", + "RES_LO", + "BITS", + "BYTE_A_3", + "BYTE_A_2", + "BYTE_A_1", + "BYTE_A_0", + + "ACC_A_A_3", + "ACC_A_2", + "ACC_A_1", + "ACC_A_0", + + "BYTE_B_3", + "BYTE_B_2", + "BYTE_B_1", + "BYTE_B_0", + + "ACC_B_3", + "ACC_B_2", + "ACC_B_1", + "ACC_B_0", + + "BYTE_C_3", + "BYTE_C_2", + "BYTE_C_1", + "BYTE_C_0", + + "ACC_C_3", + "ACC_C_2", + "ACC_C_1", + "ACC_C_0", + + "BYTE_H_3", + "BYTE_H_2", + "BYTE_H_1", + "BYTE_H_0", + + "ACC_H_3", + "ACC_H_2", + "ACC_H_1", + "ACC_H_0", + + + "EXPONENT_BIT", + "EXPONENT_BIT_ACCUMULATOR", + "EXPONENT_BIT_SOURCE", + "SQUARE_AND_MULTIPLY", + "BIT_NUM", + }) + @SuppressWarnings("unused") + public record Trace( + @JsonProperty("MUL_STAMP") List MUL_STAMP, + @JsonProperty("COUNTER") List COUNTER, + @JsonProperty("ONE_LINE_INSTRUCTION") List ONE_LINE_INSTRUCTION, + @JsonProperty("TINY_BASE") List TINY_BASE, + @JsonProperty("TINY_EXPONENT") List TINY_EXPONENT, + @JsonProperty("RESULT_VANISHES") List RESULT_VANISHES, + + @JsonProperty("INST") List INST, + @JsonProperty("ARG_1_HI") List ARG_1_HI, + @JsonProperty("ARG_1_LO") List ARG_1_LO, + @JsonProperty("ARG_2_HI") List ARG_2_HI, + @JsonProperty("ARG_2_LO") List ARG_2_LO, + @JsonProperty("RES_HI") List RES_HI, + @JsonProperty("RES_LO") List RES_LO, + @JsonProperty("BITS") List BITS, + @JsonProperty("BYTE_A_3") List BYTE_A_3, + @JsonProperty("BYTE_A_2") List BYTE_A_2, + @JsonProperty("BYTE_A_1") List BYTE_A_1, + @JsonProperty("BYTE_A_0") List BYTE_A_0, + @JsonProperty("ACC_A_3") List ACC_A_3, + @JsonProperty("ACC_A_2") List ACC_A_2, + @JsonProperty("ACC_A_1") List ACC_A_1, + @JsonProperty("ACC_A_4") List ACC_A_0, + @JsonProperty("BYTE_B_3") List BYTE_B_3, + @JsonProperty("BYTE_B_2") List BYTE_B_2, + @JsonProperty("BYTE_B_1") List BYTE_B_1, + @JsonProperty("BYTE_B_0") List BYTE_B_0, + @JsonProperty("ACC_B_3") List ACC_B_3, + @JsonProperty("ACC_B_2") List ACC_B_2, + @JsonProperty("ACC_B_1") List ACC_B_1, + @JsonProperty("ACC_B_4") List ACC_B_0, + @JsonProperty("BYTE_C_3") List BYTE_C_3, + @JsonProperty("BYTE_C_2") List BYTE_C_2, + @JsonProperty("BYTE_C_1") List BYTE_C_1, + @JsonProperty("BYTE_C_0") List BYTE_C_0, + @JsonProperty("ACC_C_3") List ACC_C_3, + @JsonProperty("ACC_C_2") List ACC_C_2, + @JsonProperty("ACC_C_1") List ACC_C_1, + @JsonProperty("ACC_C_4") List ACC_C_0, + @JsonProperty("BYTE_H_3") List BYTE_H_3, + @JsonProperty("BYTE_H_2") List BYTE_H_2, + @JsonProperty("BYTE_H_1") List BYTE_H_1, + @JsonProperty("BYTE_H_0") List BYTE_H_0, + @JsonProperty("ACC_H_3") List ACC_H_3, + @JsonProperty("ACC_H_2") List ACC_H_2, + @JsonProperty("ACC_H_1") List ACC_H_1, + @JsonProperty("ACC_H_4") List ACC_H_0, + @JsonProperty("EXPONENT_BIT") List EXPONENT_BIT, + @JsonProperty("EXPONENT_BIT_ACCUMULATOR") List EXPONENT_BIT_ACCUMULATOR, + @JsonProperty("EXPONENT_BIT_SOURCE") List EXPONENT_BIT_SOURCE, + @JsonProperty("SQUARE_AND_MULTIPLY") List SQUARE_AND_MULTIPLY, + @JsonProperty("BIT_NUM") List BIT_NUM) { + + public static class Builder { + private final List shiftStamp = new ArrayList<>(); + private final List counter = new ArrayList<>(); + private final List oneLineInstruction = new ArrayList<>(); + private final List tinyBase = new ArrayList<>(); + private final List tinyExponent = new ArrayList<>(); + private final List resultVanishes = new ArrayList<>(); + private final List inst = new ArrayList<>(); + private final List arg1Hi = new ArrayList<>(); + private final List arg1Lo = new ArrayList<>(); + private final List arg2Hi = new ArrayList<>(); + private final List arg2Lo = new ArrayList<>(); + private final List resHi = new ArrayList<>(); + private final List resLo = new ArrayList<>(); + private final List bits = new ArrayList<>(); + + private final List byteA3 = new ArrayList<>(); + private final List byteA2 = new ArrayList<>(); + private final List byteA1 = new ArrayList<>(); + private final List byteA0 = new ArrayList<>(); + private final List accA3 = new ArrayList<>(); + private final List accA2 = new ArrayList<>(); + private final List accA1 = new ArrayList<>(); + private final List accA0 = new ArrayList<>(); + + private final List byteB3 = new ArrayList<>(); + private final List byteB2 = new ArrayList<>(); + private final List byteB1 = new ArrayList<>(); + private final List byteB0 = new ArrayList<>(); + private final List accB3 = new ArrayList<>(); + private final List accB2 = new ArrayList<>(); + private final List accB1 = new ArrayList<>(); + private final List accB0 = new ArrayList<>(); + + private final List byteC3 = new ArrayList<>(); + private final List byteC2 = new ArrayList<>(); + private final List byteC1 = new ArrayList<>(); + private final List byteC0 = new ArrayList<>(); + private final List accC3 = new ArrayList<>(); + private final List accC2 = new ArrayList<>(); + private final List accC1 = new ArrayList<>(); + private final List accC0 = new ArrayList<>(); + + private final List byteH3 = new ArrayList<>(); + private final List byteH2 = new ArrayList<>(); + private final List byteH1 = new ArrayList<>(); + private final List byteH0 = new ArrayList<>(); + private final List accH3 = new ArrayList<>(); + private final List accH2 = new ArrayList<>(); + private final List accH1 = new ArrayList<>(); + private final List accH0 = new ArrayList<>(); + + private final List exponentBit = new ArrayList<>(); + // mul.Trace.PushLoBytes(EXPONENT_BIT_ACCUMULATOR.Name(), md.expAcc) // TODO: true ? risky :D + private final List exponentBitAccumulator = new ArrayList<>(); // PushLoBytes + private final List exponentBitSource = new ArrayList<>(); + private final List squareAndMultiply = new ArrayList<>(); + private final List bitNum = new ArrayList<>(); + private int stamp = 0; + + private Builder() {} + + public static Builder newInstance() { + return new Builder(); + } + + + public Builder appendAccA0(final BigInteger b) { + accA0.add(b); + return this; + } + public Builder appendAccA1(final BigInteger b) { + accA1.add(b); + return this; + } + + public Builder appendAccA2(final BigInteger b) { + accA2.add(b); + return this; + } + + public Builder appendAccA3(final BigInteger b) { + accA3.add(b); + return this; + } + + public Builder appendAccB0(final BigInteger b) { + accB0.add(b); + return this; + } + public Builder appendAccB1(final BigInteger b) { + accB1.add(b); + return this; + } + + public Builder appendAccB2(final BigInteger b) { + accB2.add(b); + return this; + } + + public Builder appendAccB3(final BigInteger b) { + accB3.add(b); + return this; + } + + public Builder appendAccC0(final BigInteger b) { + accC0.add(b); + return this; + } + public Builder appendAccC1(final BigInteger b) { + accC1.add(b); + return this; + } + + public Builder appendAccC2(final BigInteger b) { + accC2.add(b); + return this; + } + + public Builder appendAccC3(final BigInteger b) { + accC3.add(b); + return this; + } + public Builder appendAccH0(final BigInteger b) { + accH0.add(b); + return this; + } + public Builder appendAccH1(final BigInteger b) { + accH1.add(b); + return this; + } + + public Builder appendAccH2(final BigInteger b) { + accH2.add(b); + return this; + } + + public Builder appendAccH3(final BigInteger b) { + accH3.add(b); + return this; + } + + + + public Builder appendArg1Hi(final BigInteger b) { + arg1Hi.add(b); + return this; + } + + public Builder appendArg1Lo(final BigInteger b) { + arg1Lo.add(b); + return this; + } + + public Builder appendArg2Hi(final BigInteger b) { + arg2Hi.add(b); + return this; + } + + public Builder appendArg2Lo(final BigInteger b) { + arg2Lo.add(b); + return this; + } + + public Builder appendBits(final Boolean b) { + bits.add(b); + return this; + } + + public Builder appendByteA0(final UnsignedByte b) { + byteA0.add(b); + return this; + } + + public Builder appendByteA1(final UnsignedByte b) { + byteA1.add(b); + return this; + } + + public Builder appendByteA2(final UnsignedByte b) { + byteA2.add(b); + return this; + } + + public Builder appendByteA3(final UnsignedByte b) { + byteA3.add(b); + return this; + } + + public Builder appendByteB0(final UnsignedByte b) { + byteB0.add(b); + return this; + } + + public Builder appendByteB1(final UnsignedByte b) { + byteB1.add(b); + return this; + } + + public Builder appendByteB2(final UnsignedByte b) { + byteB2.add(b); + return this; + } + + public Builder appendByteB3(final UnsignedByte b) { + byteB3.add(b); + return this; + } + public Builder appendByteC0(final UnsignedByte b) { + byteC0.add(b); + return this; + } + + public Builder appendByteC1(final UnsignedByte b) { + byteC1.add(b); + return this; + } + + public Builder appendByteC2(final UnsignedByte b) { + byteC2.add(b); + return this; + } + + public Builder appendByteC3(final UnsignedByte b) { + byteC3.add(b); + return this; + } + public Builder appendByteH0(final UnsignedByte b) { + byteH0.add(b); + return this; + } + + public Builder appendByteH1(final UnsignedByte b) { + byteH1.add(b); + return this; + } + + public Builder appendByteH2(final UnsignedByte b) { + byteH2.add(b); + return this; + } + + public Builder appendByteH3(final UnsignedByte b) { + byteH3.add(b); + return this; + } + + public Builder appendCounter(final Integer b) { + counter.add(b); + return this; + } + + public Builder appendInst(final UnsignedByte b) { + inst.add(b); + return this; + } + + public Builder appendOneLineInstruction(final Boolean b) { + oneLineInstruction.add(b); + return this; + } + + public Builder appendTinyBase(final Boolean b) { + tinyBase.add(b); + return this; + } + + public Builder appendTinyExponent(final Boolean b) { + tinyExponent.add(b); + return this; + } + + public Builder appendResultVanishes(final Boolean b) { + resultVanishes.add(b); + return this; + } + + public Builder appendResHi(final BigInteger b) { + resHi.add(b); + return this; + } + + public Builder appendResLo(final BigInteger b) { + resLo.add(b); + return this; + } + + public Builder appendShiftStamp(final Integer b) { + shiftStamp.add(b); + return this; + } + + public Builder setStamp(final int stamp) { + this.stamp = stamp; + return this; + } + + public MulTrace build() { + return new MulTrace( + new Trace( + shiftStamp, + counter, + oneLineInstruction, + tinyBase, + tinyExponent, + resultVanishes, + inst, + arg1Hi, + arg1Lo, + arg2Hi, + arg2Lo, + resHi, + resLo, + bits, + + byteA3, + byteA2, + byteA1, + byteA0, + accA3, + accA2, + accA1, + accA0, + + byteB3, + byteB2, + byteB1, + byteB0, + accB3, + accB2, + accB1, + accB0, + + byteC3, + byteC2, + byteC1, + byteC0, + accC3, + accC2, + accC1, + accC0, + + byteH3, + byteH2, + byteH1, + byteH0, + accH3, + accH2, + accH1, + accH0, + + exponentBit, + exponentBitAccumulator, + exponentBitSource, + squareAndMultiply, + bitNum + ), + stamp); + } + } + } +} From cf3f002bc3104e247a53ecc7e94c7e0763cbd3cb Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 29 Mar 2023 15:58:30 +1000 Subject: [PATCH 02/31] WIP with passing tests Signed-off-by: Sally MacFarlane --- .../java/net/consensys/zktracer/OpCode.java | 3 + .../java/net/consensys/zktracer/ZkTracer.java | 3 +- .../zktracer/bytes/BytesBaseTheta.java | 62 ++++ .../zktracer/module/alu/mul/MulTrace.java | 271 +++++++++--------- .../zktracer/module/alu/mul/MulTracer.java | 167 +++++++++++ .../zktracer/module/alu/mul/Muler.java | 32 +++ .../zktracer/module/alu/mul/Res.java | 53 ++++ .../zktracer/module/shf/Shifter.java | 1 + .../module/alu/mul/MulTracerTest.java | 174 +++++++++++ .../zktracer/module/alu/mul/MulUtilsTest.java | 17 ++ 10 files changed, 654 insertions(+), 129 deletions(-) create mode 100644 src/main/java/net/consensys/zktracer/bytes/BytesBaseTheta.java create mode 100644 src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java create mode 100644 src/main/java/net/consensys/zktracer/module/alu/mul/Muler.java create mode 100644 src/main/java/net/consensys/zktracer/module/alu/mul/Res.java create mode 100644 src/test/java/net/consensys/zktracer/module/alu/mul/MulTracerTest.java create mode 100644 src/test/java/net/consensys/zktracer/module/alu/mul/MulUtilsTest.java diff --git a/src/main/java/net/consensys/zktracer/OpCode.java b/src/main/java/net/consensys/zktracer/OpCode.java index 07755c82e9..5cb0cdbb0b 100644 --- a/src/main/java/net/consensys/zktracer/OpCode.java +++ b/src/main/java/net/consensys/zktracer/OpCode.java @@ -19,6 +19,9 @@ import java.util.Map; public enum OpCode { + // mul + MUL(0x02), + EXP(0x0a), // shf SAR(0x1d), SHL(0x1b), diff --git a/src/main/java/net/consensys/zktracer/ZkTracer.java b/src/main/java/net/consensys/zktracer/ZkTracer.java index 5c1b856e7a..9c4846a54f 100644 --- a/src/main/java/net/consensys/zktracer/ZkTracer.java +++ b/src/main/java/net/consensys/zktracer/ZkTracer.java @@ -20,12 +20,13 @@ import net.consensys.zktracer.OpCode; import net.consensys.zktracer.ZkTraceBuilder; import net.consensys.zktracer.module.ModuleTracer; +import net.consensys.zktracer.module.alu.mul.MulTracer; import net.consensys.zktracer.module.shf.ShfTracer; import org.hyperledger.besu.evm.frame.MessageFrame; import org.hyperledger.besu.evm.tracing.OperationTracer; public class ZkTracer implements OperationTracer { - private final List tracers = List.of(new ShfTracer()); + private final List tracers = List.of(new MulTracer(), new ShfTracer()); private final Map> opCodeTracerMap = new HashMap<>(); private final ZkTraceBuilder zkTraceBuilder; diff --git a/src/main/java/net/consensys/zktracer/bytes/BytesBaseTheta.java b/src/main/java/net/consensys/zktracer/bytes/BytesBaseTheta.java new file mode 100644 index 0000000000..bf467ebf99 --- /dev/null +++ b/src/main/java/net/consensys/zktracer/bytes/BytesBaseTheta.java @@ -0,0 +1,62 @@ +package net.consensys.zktracer.bytes; + +import org.apache.tuweni.bytes.Bytes32; + +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; + +public class BytesBaseTheta { + + private byte[][] bytes; + + public BytesBaseTheta(final Bytes32 arg) { + bytes = new byte[4][8]; + byte[] argBytes = arg.toArray(); + + for (int k = 0; k < 4; k++) { + System.arraycopy(argBytes, 8 * k, bytes[3 - k], 0, 8); + } + } + + public Pair getHiLo() { + byte[] hiBytes = new byte[16]; + byte[] loBytes = new byte[16]; + + System.arraycopy(bytes[3], 0, hiBytes, 0, 8); + System.arraycopy(bytes[2], 0, hiBytes, 8, 8); + + System.arraycopy(bytes[1], 0, loBytes, 0, 8); + System.arraycopy(bytes[0], 0, loBytes, 8, 8); + + return new Pair<>(hiBytes, loBytes); + } + + public byte get(final int i, final int j) { + return bytes[i][j]; + } + public byte[] getRange(final int i, final int start, final int end) { + return Arrays.copyOfRange(bytes[i], start, end); + } +} + +@SuppressWarnings("UnusedVariable") +record Pair(A first, B second) { +} + + class UInt256 { + private byte[] bytes; + + public UInt256(byte[] bytes) { + this.bytes = bytes; + } + + public byte[] getBytes32() { + ByteBuffer buf = ByteBuffer.allocate(32); + buf.order(ByteOrder.BIG_ENDIAN); + buf.put(bytes); + return buf.array(); + } + } + diff --git a/src/main/java/net/consensys/zktracer/module/alu/mul/MulTrace.java b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTrace.java index 74b3469d1e..61a95f209f 100644 --- a/src/main/java/net/consensys/zktracer/module/alu/mul/MulTrace.java +++ b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTrace.java @@ -26,124 +26,136 @@ @SuppressWarnings("unused") public record MulTrace(@JsonProperty("Trace") Trace trace, @JsonProperty("Stamp") int stamp) { @JsonPropertyOrder({ - "MUL_STAMP", - "COUNTER", - "OLI", // "ONE_LINE_INSTRUCTION", - "TINY_BASE", - "TINY_EXPONENT", - "RESULT_VANISHES", - "INST", // INSTRUCTION + "ACC_A_0", + "ACC_A_1", + "ACC_A_2", + "ACC_A_3", + + "ACC_B_0", + "ACC_B_1", + "ACC_B_2", + "ACC_B_3", + + "ACC_C_0", + "ACC_C_1", + "ACC_C_2", + "ACC_C_3", + + "ACC_H_0", + "ACC_H_1", + "ACC_H_2", + "ACC_H_3", + "ARG_1_HI", "ARG_1_LO", "ARG_2_HI", "ARG_2_LO", - "RES_HI", - "RES_LO", - "BITS", - "BYTE_A_3", - "BYTE_A_2", - "BYTE_A_1", - "BYTE_A_0", - "ACC_A_A_3", - "ACC_A_2", - "ACC_A_1", - "ACC_A_0", + "BITS", + "BIT_NUM", - "BYTE_B_3", - "BYTE_B_2", - "BYTE_B_1", - "BYTE_B_0", + "BYTE_A_0", + "BYTE_A_1", + "BYTE_A_2", + "BYTE_A_3", - "ACC_B_3", - "ACC_B_2", - "ACC_B_1", - "ACC_B_0", + "BYTE_B_0", + "BYTE_B_1", + "BYTE_B_2", + "BYTE_B_3", - "BYTE_C_3", - "BYTE_C_2", - "BYTE_C_1", "BYTE_C_0", + "BYTE_C_1", + "BYTE_C_2", + "BYTE_C_3", - "ACC_C_3", - "ACC_C_2", - "ACC_C_1", - "ACC_C_0", - - "BYTE_H_3", - "BYTE_H_2", - "BYTE_H_1", "BYTE_H_0", + "BYTE_H_1", + "BYTE_H_2", + "BYTE_H_3", - "ACC_H_3", - "ACC_H_2", - "ACC_H_1", - "ACC_H_0", + "COUNTER", + + "EXPONENT_BIT", + "EXPONENT_BIT_ACCUMULATOR", + "EXPONENT_BIT_SOURCE", + "INST", // INSTRUCTION + "MUL_STAMP", + "OLI", // "ONE_LINE_INSTRUCTION", + "RESULT_VANISHES", + + "RES_HI", + "RES_LO", - "EXPONENT_BIT", - "EXPONENT_BIT_ACCUMULATOR", - "EXPONENT_BIT_SOURCE", - "SQUARE_AND_MULTIPLY", - "BIT_NUM", + "SQUARE_AND_MULTIPLY", + "TINY_BASE", + "TINY_EXPONENT", }) @SuppressWarnings("unused") public record Trace( - @JsonProperty("MUL_STAMP") List MUL_STAMP, - @JsonProperty("COUNTER") List COUNTER, - @JsonProperty("ONE_LINE_INSTRUCTION") List ONE_LINE_INSTRUCTION, - @JsonProperty("TINY_BASE") List TINY_BASE, - @JsonProperty("TINY_EXPONENT") List TINY_EXPONENT, - @JsonProperty("RESULT_VANISHES") List RESULT_VANISHES, + @JsonProperty("ACC_A_0") List ACC_A_0, + @JsonProperty("ACC_A_1") List ACC_A_1, + @JsonProperty("ACC_A_2") List ACC_A_2, + @JsonProperty("ACC_A_3") List ACC_A_3, + @JsonProperty("ACC_B_0") List ACC_B_0, + @JsonProperty("ACC_B_1") List ACC_B_1, + @JsonProperty("ACC_B_2") List ACC_B_2, + @JsonProperty("ACC_B_3") List ACC_B_3, + @JsonProperty("ACC_C_0") List ACC_C_0, + @JsonProperty("ACC_C_1") List ACC_C_1, + @JsonProperty("ACC_C_2") List ACC_C_2, + @JsonProperty("ACC_C_3") List ACC_C_3, + @JsonProperty("ACC_H_0") List ACC_H_0, + @JsonProperty("ACC_H_1") List ACC_H_1, + @JsonProperty("ACC_H_2") List ACC_H_2, + @JsonProperty("ACC_H_3") List ACC_H_3, - @JsonProperty("INST") List INST, @JsonProperty("ARG_1_HI") List ARG_1_HI, @JsonProperty("ARG_1_LO") List ARG_1_LO, @JsonProperty("ARG_2_HI") List ARG_2_HI, @JsonProperty("ARG_2_LO") List ARG_2_LO, - @JsonProperty("RES_HI") List RES_HI, - @JsonProperty("RES_LO") List RES_LO, @JsonProperty("BITS") List BITS, - @JsonProperty("BYTE_A_3") List BYTE_A_3, - @JsonProperty("BYTE_A_2") List BYTE_A_2, - @JsonProperty("BYTE_A_1") List BYTE_A_1, + @JsonProperty("BIT_NUM") List BIT_NUM, + @JsonProperty("BYTE_A_0") List BYTE_A_0, - @JsonProperty("ACC_A_3") List ACC_A_3, - @JsonProperty("ACC_A_2") List ACC_A_2, - @JsonProperty("ACC_A_1") List ACC_A_1, - @JsonProperty("ACC_A_4") List ACC_A_0, - @JsonProperty("BYTE_B_3") List BYTE_B_3, - @JsonProperty("BYTE_B_2") List BYTE_B_2, - @JsonProperty("BYTE_B_1") List BYTE_B_1, + @JsonProperty("BYTE_A_1") List BYTE_A_1, + @JsonProperty("BYTE_A_2") List BYTE_A_2, + @JsonProperty("BYTE_A_3") List BYTE_A_3, + @JsonProperty("BYTE_B_0") List BYTE_B_0, - @JsonProperty("ACC_B_3") List ACC_B_3, - @JsonProperty("ACC_B_2") List ACC_B_2, - @JsonProperty("ACC_B_1") List ACC_B_1, - @JsonProperty("ACC_B_4") List ACC_B_0, - @JsonProperty("BYTE_C_3") List BYTE_C_3, - @JsonProperty("BYTE_C_2") List BYTE_C_2, - @JsonProperty("BYTE_C_1") List BYTE_C_1, + @JsonProperty("BYTE_B_1") List BYTE_B_1, + @JsonProperty("BYTE_B_2") List BYTE_B_2, + @JsonProperty("BYTE_B_3") List BYTE_B_3, + @JsonProperty("BYTE_C_0") List BYTE_C_0, - @JsonProperty("ACC_C_3") List ACC_C_3, - @JsonProperty("ACC_C_2") List ACC_C_2, - @JsonProperty("ACC_C_1") List ACC_C_1, - @JsonProperty("ACC_C_4") List ACC_C_0, - @JsonProperty("BYTE_H_3") List BYTE_H_3, - @JsonProperty("BYTE_H_2") List BYTE_H_2, - @JsonProperty("BYTE_H_1") List BYTE_H_1, + @JsonProperty("BYTE_C_1") List BYTE_C_1, + @JsonProperty("BYTE_C_2") List BYTE_C_2, + @JsonProperty("BYTE_C_3") List BYTE_C_3, + @JsonProperty("BYTE_H_0") List BYTE_H_0, - @JsonProperty("ACC_H_3") List ACC_H_3, - @JsonProperty("ACC_H_2") List ACC_H_2, - @JsonProperty("ACC_H_1") List ACC_H_1, - @JsonProperty("ACC_H_4") List ACC_H_0, + @JsonProperty("BYTE_H_1") List BYTE_H_1, + @JsonProperty("BYTE_H_2") List BYTE_H_2, + @JsonProperty("BYTE_H_3") List BYTE_H_3, + + @JsonProperty("COUNTER") List COUNTER, @JsonProperty("EXPONENT_BIT") List EXPONENT_BIT, - @JsonProperty("EXPONENT_BIT_ACCUMULATOR") List EXPONENT_BIT_ACCUMULATOR, + @JsonProperty("EXPONENT_BIT_ACCUMULATOR") List EXPONENT_BIT_ACCUMULATOR, @JsonProperty("EXPONENT_BIT_SOURCE") List EXPONENT_BIT_SOURCE, + + @JsonProperty("INST") List INST, + + @JsonProperty("MUL_STAMP") List MUL_STAMP, + @JsonProperty("ONE_LINE_INSTRUCTION") List ONE_LINE_INSTRUCTION, + @JsonProperty("RESULT_VANISHES") List RESULT_VANISHES, + @JsonProperty("RES_HI") List RES_HI, + @JsonProperty("RES_LO") List RES_LO, @JsonProperty("SQUARE_AND_MULTIPLY") List SQUARE_AND_MULTIPLY, - @JsonProperty("BIT_NUM") List BIT_NUM) { + @JsonProperty("TINY_BASE") List TINY_BASE, + @JsonProperty("TINY_EXPONENT") List TINY_EXPONENT) { public static class Builder { - private final List shiftStamp = new ArrayList<>(); + private final List mulStamp = new ArrayList<>(); private final List counter = new ArrayList<>(); private final List oneLineInstruction = new ArrayList<>(); private final List tinyBase = new ArrayList<>(); @@ -196,7 +208,7 @@ public static class Builder { private final List exponentBit = new ArrayList<>(); // mul.Trace.PushLoBytes(EXPONENT_BIT_ACCUMULATOR.Name(), md.expAcc) // TODO: true ? risky :D - private final List exponentBitAccumulator = new ArrayList<>(); // PushLoBytes + private final List exponentBitAccumulator = new ArrayList<>(); private final List exponentBitSource = new ArrayList<>(); private final List squareAndMultiply = new ArrayList<>(); private final List bitNum = new ArrayList<>(); @@ -429,8 +441,8 @@ public Builder appendResLo(final BigInteger b) { return this; } - public Builder appendShiftStamp(final Integer b) { - shiftStamp.add(b); + public Builder appendStamp(final Integer b) { + mulStamp.add(b); return this; } @@ -442,62 +454,65 @@ public Builder setStamp(final int stamp) { public MulTrace build() { return new MulTrace( new Trace( - shiftStamp, - counter, - oneLineInstruction, - tinyBase, - tinyExponent, - resultVanishes, - inst, + accA0, + accA1, + accA2, + accA3, + accB0, + accB1, + accB2, + accB3, + accC0, + accC1, + accC2, + accC3, + accH0, + accH1, + accH2, + accH3, + arg1Hi, arg1Lo, arg2Hi, arg2Lo, - resHi, - resLo, + bits, + bitNum, - byteA3, - byteA2, - byteA1, byteA0, - accA3, - accA2, - accA1, - accA0, + byteA1, + byteA2, + byteA3, - byteB3, - byteB2, - byteB1, byteB0, - accB3, - accB2, - accB1, - accB0, + byteB1, + byteB2, + byteB3, - byteC3, - byteC2, - byteC1, byteC0, - accC3, - accC2, - accC1, - accC0, + byteC1, + byteC2, + byteC3, - byteH3, - byteH2, - byteH1, byteH0, - accH3, - accH2, - accH1, - accH0, - + byteH1, + byteH2, + byteH3, + + counter, exponentBit, exponentBitAccumulator, exponentBitSource, + + inst, + mulStamp, + oneLineInstruction, + resultVanishes, + resHi, + resLo, squareAndMultiply, - bitNum + tinyBase, + tinyExponent ), stamp); } diff --git a/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java new file mode 100644 index 0000000000..6487b54d6c --- /dev/null +++ b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java @@ -0,0 +1,167 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package net.consensys.zktracer.module.alu.mul; + +import net.consensys.zktracer.OpCode; +import net.consensys.zktracer.bytes.Bytes16; +import net.consensys.zktracer.bytes.BytesBaseTheta; +import net.consensys.zktracer.bytes.UnsignedByte; +import net.consensys.zktracer.module.ModuleTracer; +import org.apache.tuweni.bytes.Bytes; +import org.apache.tuweni.bytes.Bytes32; +import org.apache.tuweni.units.bigints.UInt256; +import org.hyperledger.besu.evm.frame.MessageFrame; + +import java.math.BigInteger; +import java.util.List; + +public class MulTracer implements ModuleTracer { + private static final int MMEDIUM = 8; + + private int stamp = 0; + + @Override + public String jsonKey() { + return "mul"; + } + + @Override + public List supportedOpCodes() { + return List.of(OpCode.MUL, OpCode.EXP); + } + + @Override + public Object trace(MessageFrame frame) { + final Bytes32 arg1 = Bytes32.wrap(frame.getStackItem(0)); + final Bytes32 arg2 = Bytes32.wrap(frame.getStackItem(1)); + + final Bytes16 arg1Hi = Bytes16.wrap(arg1.slice(0, 16)); + final Bytes16 arg1Lo = Bytes16.wrap(arg1.slice(16)); + final Bytes16 arg2Hi = Bytes16.wrap(arg2.slice(0, 16)); + final Bytes16 arg2Lo = Bytes16.wrap(arg2.slice(16)); + + final UInt256 arg1Int = UInt256.fromBytes(arg1); + final UInt256 arg2Int = UInt256.fromBytes(arg2); + final BigInteger arg1BigInt = arg1Int.toUnsignedBigInteger(); + final BigInteger arg2BigInt = arg2Int.toUnsignedBigInteger(); + + final OpCode opCode = OpCode.of(frame.getCurrentOperation().getOpcode()); + + final boolean tinyBase = isTiny(arg1BigInt); + final boolean tinyExponent = isTiny(arg2BigInt); + + final boolean isOneLineInstruction = isOneLineInstruction(tinyBase, tinyExponent); + final Res res = Res.create(opCode, arg1, arg2); + + final Regime regime = getRegime(opCode, tinyBase, tinyExponent, res); + System.out.println(regime); + + final MulTrace.Trace.Builder builder = MulTrace.Trace.Builder.newInstance(); + + final BytesBaseTheta aBytes = new BytesBaseTheta(arg1); + final BytesBaseTheta bBytes = new BytesBaseTheta(arg2); + + + stamp++; + for (int i = 0; i < maxCt(isOneLineInstruction); i++) { + builder.appendStamp(stamp); + builder.appendCounter(i); + + builder + .appendOneLineInstruction(isOneLineInstruction) + .appendTinyBase(tinyBase) + .appendTinyExponent(tinyExponent) + .appendResultVanishes(res.isZero()); + + builder + .appendInst(UnsignedByte.of(opCode.value)) + .appendArg1Hi(arg1Hi.toUnsignedBigInteger()) + .appendArg1Lo(arg1Lo.toUnsignedBigInteger()) + .appendArg2Hi(arg2Hi.toUnsignedBigInteger()) + .appendArg2Lo(arg2Lo.toUnsignedBigInteger()); + + builder + .appendResHi(res.getResHi().toUnsignedBigInteger()) + .appendResLo(res.getResLo().toUnsignedBigInteger()); + +// builder.appendBits(bits.get(i)).appendCounter(i); // TODO + + + builder + .appendByteA3(UnsignedByte.of(aBytes.get(3, i))) + .appendByteA2(UnsignedByte.of(aBytes.get(2, i))) + .appendByteA1(UnsignedByte.of(aBytes.get(1, i))) + .appendByteA0(UnsignedByte.of(aBytes.get(0, i))); + builder + .appendAccA3(Bytes.of(aBytes.getRange(3, 0, i+1)).toUnsignedBigInteger()) + .appendAccA2(Bytes.of(aBytes.getRange(2, 0, i+1)).toUnsignedBigInteger()) + .appendAccA1(Bytes.of(aBytes.getRange(1, 0, i+1)).toUnsignedBigInteger()) + .appendAccA0(Bytes.of(aBytes.getRange(0, 0, i+1)).toUnsignedBigInteger()); + + builder + .appendByteB3(UnsignedByte.of(bBytes.get(3, i))) + .appendByteB2(UnsignedByte.of(bBytes.get(2, i))) + .appendByteB1(UnsignedByte.of(bBytes.get(1, i))) + .appendByteB0(UnsignedByte.of(bBytes.get(0, i))); + builder + .appendAccB3(Bytes.of(bBytes.getRange(3, 0, i+1)).toUnsignedBigInteger()) + .appendAccB2(Bytes.of(bBytes.getRange(2, 0, i+1)).toUnsignedBigInteger()) + .appendAccB1(Bytes.of(bBytes.getRange(1, 0, i+1)).toUnsignedBigInteger()) + .appendAccB0(Bytes.of(bBytes.getRange(0, 0, i+1)).toUnsignedBigInteger()); + + } + builder.setStamp(stamp); + + return builder.build(); + } + + public static boolean isTiny(BigInteger arg) { + return arg.compareTo(BigInteger.valueOf(1)) <= 0; + } + + private int maxCt(final boolean isOneLineInstruction) { + return isOneLineInstruction ? 1 : MMEDIUM; + } + + private boolean isOneLineInstruction(final boolean tinyBase, final boolean tinyExponent) { + return tinyBase || tinyExponent; + } + + private enum Regime { + IOTA, + TRIVIAL_MUL, + NON_TRIVIAL_MUL, + EXPONENT_ZERO_RESULT, + EXPONENT_NON_ZERO_RESULT + } + private Regime getRegime(final OpCode opCode, final boolean tinyBase, final boolean tinyExponent, final Res res) { + + if (isOneLineInstruction(tinyBase, tinyExponent)) return Regime.TRIVIAL_MUL; + + + if (OpCode.MUL.equals(opCode)) { + return Regime.NON_TRIVIAL_MUL; + } + + if (OpCode.EXP.equals(opCode)) { + if (res.isZero()) { + return Regime.EXPONENT_ZERO_RESULT; + } else { + return Regime.EXPONENT_NON_ZERO_RESULT; + } + } + return Regime.IOTA; + } +} diff --git a/src/main/java/net/consensys/zktracer/module/alu/mul/Muler.java b/src/main/java/net/consensys/zktracer/module/alu/mul/Muler.java new file mode 100644 index 0000000000..2f148c129c --- /dev/null +++ b/src/main/java/net/consensys/zktracer/module/alu/mul/Muler.java @@ -0,0 +1,32 @@ +package net.consensys.zktracer.module.alu.mul; +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +import net.consensys.zktracer.OpCode; +import org.apache.tuweni.bytes.Bytes; +import org.apache.tuweni.bytes.Bytes32; +import org.apache.tuweni.units.bigints.UInt256; + +import java.math.BigInteger; + +public class Muler { + + public static UInt256 operate(final OpCode opCode, final Bytes32 arg1, final Bytes32 arg2) { + return switch (opCode) { + case MUL -> UInt256.fromBytes(arg1).multiply(UInt256.fromBytes(arg2)); + case EXP -> UInt256.fromBytes(arg1).pow(UInt256.fromBytes(arg2)); + default -> UInt256.ZERO; + }; + } +} diff --git a/src/main/java/net/consensys/zktracer/module/alu/mul/Res.java b/src/main/java/net/consensys/zktracer/module/alu/mul/Res.java new file mode 100644 index 0000000000..f3074f7b21 --- /dev/null +++ b/src/main/java/net/consensys/zktracer/module/alu/mul/Res.java @@ -0,0 +1,53 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package net.consensys.zktracer.module.alu.mul; + +import net.consensys.zktracer.OpCode; +import net.consensys.zktracer.bytes.Bytes16; +import net.consensys.zktracer.module.shf.Shifter; +import org.apache.tuweni.bytes.Bytes32; + +import java.math.BigInteger; + +public class Res { + final Bytes16 resHi; + final Bytes16 resLo; + final boolean isZero; + + private Res(Bytes16 resHi, Bytes16 resLo, boolean isZero) { + this.resHi = resHi; + this.resLo = resLo; + this.isZero = isZero; + } + + public Bytes16 getResHi() { + return resHi; + } + + public Bytes16 getResLo() { + return resLo; + } + + public static Res create(final OpCode opCode, final Bytes32 arg1, final Bytes32 arg2) { + final Bytes32 result = Muler.operate(opCode, arg1, arg2); + + return new Res(Bytes16.wrap(result.slice(0, 16)), Bytes16.wrap(result.slice(16)), result.isZero()); + } + + public boolean isZero() { + return isZero; + } + +} diff --git a/src/main/java/net/consensys/zktracer/module/shf/Shifter.java b/src/main/java/net/consensys/zktracer/module/shf/Shifter.java index 81186100ca..d08d43bafb 100644 --- a/src/main/java/net/consensys/zktracer/module/shf/Shifter.java +++ b/src/main/java/net/consensys/zktracer/module/shf/Shifter.java @@ -26,6 +26,7 @@ public static Bytes32 shift(final OpCode opCode, final Bytes32 value, final int case SHR -> value.shiftRight(shiftAmount); case SHL -> value.shiftLeft(shiftAmount); case SAR -> sarOperation(value, shiftAmount); + default -> Bytes32.ZERO; }; } diff --git a/src/test/java/net/consensys/zktracer/module/alu/mul/MulTracerTest.java b/src/test/java/net/consensys/zktracer/module/alu/mul/MulTracerTest.java new file mode 100644 index 0000000000..7d2f732343 --- /dev/null +++ b/src/test/java/net/consensys/zktracer/module/alu/mul/MulTracerTest.java @@ -0,0 +1,174 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package net.consensys.zktracer.module.alu.mul; + +import net.consensys.zktracer.CorsetValidator; +import net.consensys.zktracer.OpCode; +import net.consensys.zktracer.ZkTraceBuilder; +import net.consensys.zktracer.ZkTracer; +import org.apache.tuweni.bytes.Bytes; +import org.apache.tuweni.bytes.Bytes32; +import org.hyperledger.besu.evm.frame.MessageFrame; +import org.hyperledger.besu.evm.operation.Operation; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Named; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.math.BigInteger; +import java.util.Random; +import java.util.stream.Stream; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class MulTracerTest { + private static final Logger LOG = LoggerFactory.getLogger(MulTracerTest.class); + + private static final Random rand = new Random(); + private static final int TEST_REPETITIONS = 4; + + private ZkTracer zkTracer; + private ZkTraceBuilder zkTraceBuilder; + + @Mock MessageFrame mockFrame; + @Mock Operation mockOperation; + + @BeforeEach + void setUp() { + zkTraceBuilder = new ZkTraceBuilder(); + zkTracer = new ZkTracer(zkTraceBuilder); + + when(mockFrame.getCurrentOperation()).thenReturn(mockOperation); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideMulOperators") + void testFailingBlockchainBlock(final int opCodeValue) { + when(mockOperation.getOpcode()).thenReturn(opCodeValue); + + when(mockFrame.getStackItem(0)).thenReturn(Bytes32.rightPad(Bytes.fromHexString("0x08"))); + when(mockFrame.getStackItem(1)).thenReturn(Bytes32.fromHexString("0x0128")); + + zkTracer.tracePreExecution(mockFrame); + + assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideRandomArguments") + void testRandomExp(final Bytes32[] payload) { + LOG.info( + "arg1: " + payload[0].toShortHexString() + ", arg2: " + payload[1].toShortHexString()); + when(mockOperation.getOpcode()).thenReturn((int) OpCode.EXP.value); + + when(mockFrame.getStackItem(0)).thenReturn(payload[0]); + when(mockFrame.getStackItem(1)).thenReturn(payload[1]); + + zkTracer.tracePreExecution(mockFrame); + + assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideNonRandomArguments") + void testNonRandomMul(final Bytes32[] payload) { + LOG.info( + "arg1: " + payload[0].toShortHexString() + ", arg2: " + payload[1].toShortHexString()); + when(mockOperation.getOpcode()).thenReturn((int) OpCode.EXP.value); + + when(mockFrame.getStackItem(0)).thenReturn(payload[0]); + when(mockFrame.getStackItem(1)).thenReturn(payload[1]); + + zkTracer.tracePreExecution(mockFrame); + + assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + } + + @Test + void testTmp() { + when(mockOperation.getOpcode()).thenReturn((int) OpCode.MUL.value); + + when(mockFrame.getStackItem(0)) + .thenReturn(Bytes32.fromHexStringLenient("0x54fda4f3c1452c8c58df4fb1e9d6de")); + when(mockFrame.getStackItem(1)).thenReturn(Bytes32.fromHexStringLenient("0xb5")); + + zkTracer.tracePreExecution(mockFrame); + + assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + } + + + public static Stream provideNonRandomArguments() { + final Arguments[] arguments = new Arguments[TEST_REPETITIONS]; + + for (int i = 0; i < TEST_REPETITIONS; i++) { + Bytes32[] payload = new Bytes32[2]; + payload[0] = Bytes32.leftPad(Bytes.of(1+i)); + payload[1] = Bytes32.leftPad(Bytes.of(i)); + arguments[i] = + Arguments.of( + Named.of( + "arg1: " + + payload[0] + + ", arg2: " + + payload[1], + payload)); + } + + return Stream.of(arguments); + } + + public static Stream provideRandomArguments() { + final Arguments[] arguments = new Arguments[TEST_REPETITIONS]; + + for (int i = 0; i < TEST_REPETITIONS; i++) { + + final byte[] randomBytes1 = new byte[32]; + rand.nextBytes(randomBytes1); + final byte[] randomBytes2 = new byte[32]; + rand.nextBytes(randomBytes2); + + Bytes32[] payload = new Bytes32[2]; + payload[0] = Bytes32.wrap(randomBytes1); + payload[1] = Bytes32.wrap(randomBytes2); + + arguments[i] = + Arguments.of( + Named.of( + "arg1: " + + payload[0].toHexString() + + ", arg2: " + + payload[1].toHexString(), + payload)); + } + + return Stream.of(arguments); + } + public static Stream provideMulOperators() { + return Stream.of( + Arguments.of(Named.of("MUL", (int) OpCode.MUL.value)), + Arguments.of(Named.of("EXP", (int) OpCode.EXP.value))); + } + +} diff --git a/src/test/java/net/consensys/zktracer/module/alu/mul/MulUtilsTest.java b/src/test/java/net/consensys/zktracer/module/alu/mul/MulUtilsTest.java new file mode 100644 index 0000000000..81834857a3 --- /dev/null +++ b/src/test/java/net/consensys/zktracer/module/alu/mul/MulUtilsTest.java @@ -0,0 +1,17 @@ +package net.consensys.zktracer.module.alu.mul; + +import org.junit.jupiter.api.Test; + +import java.math.BigInteger; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +public class MulUtilsTest { + @Test + public void isTiny() { + assertThat(MulTracer.isTiny(BigInteger.ZERO)).isTrue(); + assertThat(MulTracer.isTiny(BigInteger.ONE)).isTrue(); + assertThat(MulTracer.isTiny(BigInteger.TWO)).isFalse(); + assertThat(MulTracer.isTiny(BigInteger.TEN)).isFalse(); + } +} From 230137f41dbe0c7bacc28343e3ee69cc7a21991b Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Thu, 30 Mar 2023 14:46:13 +1000 Subject: [PATCH 03/31] added lineCount method Signed-off-by: Sally MacFarlane --- .../zktracer/module/alu/mul/MulTracer.java | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java index 6487b54d6c..3e3086c3ea 100644 --- a/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java @@ -164,4 +164,18 @@ private Regime getRegime(final OpCode opCode, final boolean tinyBase, final bool } return Regime.IOTA; } + + public int lineCount(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { + + + final UInt256 arg1Int = UInt256.fromBytes(arg1); + final UInt256 arg2Int = UInt256.fromBytes(arg2); + final BigInteger arg1BigInt = arg1Int.toUnsignedBigInteger(); + final BigInteger arg2BigInt = arg2Int.toUnsignedBigInteger(); + + final boolean tinyBase = isTiny(arg1BigInt); + final boolean tinyExponent = isTiny(arg2BigInt); + + return maxCt(isOneLineInstruction(tinyBase, tinyExponent)); + } } From c1625681731b9b712de84593ec853f48fe38b15a Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Fri, 31 Mar 2023 10:45:27 +1000 Subject: [PATCH 04/31] added a test from go with specific byte32 values Signed-off-by: Sally MacFarlane --- .../module/alu/mul/MulTracerTest.java | 41 +++++++++++++++++-- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/src/test/java/net/consensys/zktracer/module/alu/mul/MulTracerTest.java b/src/test/java/net/consensys/zktracer/module/alu/mul/MulTracerTest.java index 7d2f732343..c95ff2aa5b 100644 --- a/src/test/java/net/consensys/zktracer/module/alu/mul/MulTracerTest.java +++ b/src/test/java/net/consensys/zktracer/module/alu/mul/MulTracerTest.java @@ -34,7 +34,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.math.BigInteger; import java.util.Random; import java.util.stream.Stream; @@ -91,8 +90,23 @@ void testRandomExp(final Bytes32[] payload) { } @ParameterizedTest(name = "{0}") - @MethodSource("provideNonRandomArguments") - void testNonRandomMul(final Bytes32[] payload) { + @MethodSource("provideNonRandomTinyArguments") + void testNonRandomTinyMul(final Bytes32[] payload) { + LOG.info( + "arg1: " + payload[0].toShortHexString() + ", arg2: " + payload[1].toShortHexString()); + when(mockOperation.getOpcode()).thenReturn((int) OpCode.EXP.value); + + when(mockFrame.getStackItem(0)).thenReturn(payload[0]); + when(mockFrame.getStackItem(1)).thenReturn(payload[1]); + + zkTracer.tracePreExecution(mockFrame); + + assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideNonRandomNonTinyArguments") + void testNonRandomNonTinyMul(final Bytes32[] payload) { LOG.info( "arg1: " + payload[0].toShortHexString() + ", arg2: " + payload[1].toShortHexString()); when(mockOperation.getOpcode()).thenReturn((int) OpCode.EXP.value); @@ -118,8 +132,27 @@ void testTmp() { assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); } + public static Stream provideNonRandomNonTinyArguments() { +// these values are used in Go module test +// 0x8a, 0x48, 0xaa, 0x20, 0xe2, 0x00, 0xce, 0x3f, 0xee, 0x16, 0xb5, 0xdc, 0xde, 0xc5, 0xc4, 0xfa, +// 0xff, 0x61, 0x3b, 0xc9, 0x14, 0xd4, 0x7c, 0xd6, 0xca, 0x69, 0x55, 0x3f, 0x8e, 0xb2, 0xb3, 0x77, +// byte(vm.PUSH32), +// 0x59, 0xb6, 0x35, 0xfe, 0xc8, 0x94, 0xca, 0xa3, 0xed, 0x68, 0x17, 0xb1, 0xe6, 0x7b, 0x3c, 0xba, +// 0xeb, 0x87, 0x57, 0xfd, 0x6c, 0x7b, 0x03, 0x11, 0x9b, 0x79, 0x53, 0x03, 0xb7, 0xcd, 0x72, 0xc1, + final Bytes32[] payload = new Bytes32[2]; + payload[0] = Bytes32.fromHexString("0x8a48aa20e200ce3fee16b5dcdec5c4faff613bc914d47cd6ca69553f8eb2b377"); + payload[1] = Bytes32.fromHexString("0x59b635fec894caa3ed6817b1e67b3cbaeb8757fd6c7b03119b795303b7cd72c1"); + return Stream.of( + Arguments.of( + Named.of( + "arg1: " + + payload[0] + + ", arg2: " + + payload[1], + payload))); + } - public static Stream provideNonRandomArguments() { + public static Stream provideNonRandomTinyArguments() { final Arguments[] arguments = new Arguments[TEST_REPETITIONS]; for (int i = 0; i < TEST_REPETITIONS; i++) { From c0037b2232d1f2ddda1bc3ce0717b08e14786547 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 12 Apr 2023 07:10:47 +1000 Subject: [PATCH 05/31] bytesTheta Signed-off-by: Sally MacFarlane --- .../zktracer/bytes/BytesBaseTheta.java | 15 ++++++++ .../zktracer/module/alu/mul/MulTracer.java | 34 +++++++++++++++++-- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/main/java/net/consensys/zktracer/bytes/BytesBaseTheta.java b/src/main/java/net/consensys/zktracer/bytes/BytesBaseTheta.java index bf467ebf99..66c8512e61 100644 --- a/src/main/java/net/consensys/zktracer/bytes/BytesBaseTheta.java +++ b/src/main/java/net/consensys/zktracer/bytes/BytesBaseTheta.java @@ -1,5 +1,6 @@ package net.consensys.zktracer.bytes; +import net.consensys.zktracer.module.alu.mul.Res; import org.apache.tuweni.bytes.Bytes32; import java.math.BigInteger; @@ -20,6 +21,20 @@ public BytesBaseTheta(final Bytes32 arg) { } } + public BytesBaseTheta(final Res res) { + bytes = new byte[4][8]; + byte[] argBytesHi = res.getResHi().toArray(); + byte[] argBytesLo = res.getResLo().toArray(); + + for (int k = 0; k < 2; k++) { + System.arraycopy(argBytesHi, 8 * k, bytes[3 - k], 0, 8); + } + for (int k = 2; k < 4; k++) { + System.arraycopy(argBytesLo, 8 * k, bytes[3 - k], 0, 8); + } + } + + // TODO can Res become Pair as below public Pair getHiLo() { byte[] hiBytes = new byte[16]; byte[] loBytes = new byte[16]; diff --git a/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java index 3e3086c3ea..7a0d2036a4 100644 --- a/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java @@ -65,13 +65,27 @@ public Object trace(MessageFrame frame) { final boolean isOneLineInstruction = isOneLineInstruction(tinyBase, tinyExponent); final Res res = Res.create(opCode, arg1, arg2); - final Regime regime = getRegime(opCode, tinyBase, tinyExponent, res); - System.out.println(regime); - final MulTrace.Trace.Builder builder = MulTrace.Trace.Builder.newInstance(); final BytesBaseTheta aBytes = new BytesBaseTheta(arg1); final BytesBaseTheta bBytes = new BytesBaseTheta(arg2); + BytesBaseTheta cBytes ; + BytesBaseTheta hBytes ; + boolean snm = false; + + final Regime regime = getRegime(opCode, tinyBase, tinyExponent, res); + System.out.println(regime); + switch (regime) { + case TRIVIAL_MUL: break; + case NON_TRIVIAL_MUL: + cBytes = new BytesBaseTheta(res); + case EXPONENT_ZERO_RESULT: + setArraysForZeroResultCase(); + case EXPONENT_NON_ZERO_RESULT: + setExponentBit(); + snm = false; + case IOTA: throw new RuntimeException("alu/mul regime was never set"); + } stamp++; @@ -120,6 +134,11 @@ public Object trace(MessageFrame frame) { .appendAccB2(Bytes.of(bBytes.getRange(2, 0, i+1)).toUnsignedBigInteger()) .appendAccB1(Bytes.of(bBytes.getRange(1, 0, i+1)).toUnsignedBigInteger()) .appendAccB0(Bytes.of(bBytes.getRange(0, 0, i+1)).toUnsignedBigInteger()); + builder + .appendByteC3(UnsignedByte.of(cBytes.get(3, i))) + .appendByteC2(UnsignedByte.of(cBytes.get(2, i))) + .appendByteC1(UnsignedByte.of(cBytes.get(1, i))) + .appendByteC0(UnsignedByte.of(cBytes.get(0, i))); } builder.setStamp(stamp); @@ -127,6 +146,15 @@ public Object trace(MessageFrame frame) { return builder.build(); } + private void setArraysForZeroResultCase() { + // TODO + } + private boolean setExponentBit() { + // TODO + return false; +// return string(exponentBits[md.index]) == "1"; + } + public static boolean isTiny(BigInteger arg) { return arg.compareTo(BigInteger.valueOf(1)) <= 0; } From 677078d09c1200ea7a2fa0286a3b680f3f00704e Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 12 Apr 2023 07:11:02 +1000 Subject: [PATCH 06/31] initialize cBytes Signed-off-by: Sally MacFarlane --- .../java/net/consensys/zktracer/module/alu/mul/MulTracer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java index 7a0d2036a4..a84385f8c7 100644 --- a/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/zktracer/module/alu/mul/MulTracer.java @@ -69,7 +69,7 @@ public Object trace(MessageFrame frame) { final BytesBaseTheta aBytes = new BytesBaseTheta(arg1); final BytesBaseTheta bBytes = new BytesBaseTheta(arg2); - BytesBaseTheta cBytes ; + BytesBaseTheta cBytes = null; BytesBaseTheta hBytes ; boolean snm = false; From 77ebccc78858b4bf65809c4ca46a49f05fe5ece7 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 12 Apr 2023 10:07:17 +1000 Subject: [PATCH 07/31] rename test Signed-off-by: Sally MacFarlane --- .../consensys/linea/zktracer/module/alu/mul/MulTracerTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java index 74d29b269b..bdc92029d7 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java @@ -119,7 +119,7 @@ void testNonRandomNonTinyMul(final Bytes32[] payload) { } @Test - void testTmp() { + void testSimpleMul() { when(mockOperation.getOpcode()).thenReturn((int) OpCode.MUL.value); when(mockFrame.getStackItem(0)) From 2006729312235c147a04a9a517c7dea78b3ccc8f Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 12 Apr 2023 10:27:24 +1000 Subject: [PATCH 08/31] refactor with MulData class Signed-off-by: Sally MacFarlane --- .../zktracer/module/alu/mul/MulData.java | 105 +++++++++++++ .../zktracer/module/alu/mul/MulTracer.java | 146 ++++-------------- .../zktracer/module/alu/mul/MulUtilsTest.java | 8 +- 3 files changed, 136 insertions(+), 123 deletions(-) create mode 100644 src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java new file mode 100644 index 0000000000..b439c63176 --- /dev/null +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -0,0 +1,105 @@ +package net.consensys.linea.zktracer.module.alu.mul; + +import net.consensys.linea.zktracer.OpCode; +import net.consensys.linea.zktracer.bytes.BytesBaseTheta; +import org.apache.tuweni.bytes.Bytes32; +import org.apache.tuweni.units.bigints.UInt256; + +import java.math.BigInteger; + +@SuppressWarnings("UnusedVariable") +public class MulData { + final OpCode opCode; + final boolean tinyBase; + final boolean tinyExponent; + + final BytesBaseTheta aBytes; + final BytesBaseTheta bBytes; + BytesBaseTheta cBytes; + BytesBaseTheta hBytes; + boolean snm = false; + Res res; + public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { + + this.opCode = opCode; + this.aBytes = new BytesBaseTheta(arg1); + this.bBytes = new BytesBaseTheta(arg2); + this.cBytes = null; + this.hBytes = null; + boolean snm = false; + + this.res = Res.create(opCode, arg1, arg2); // TODO can we get this from the EVM + + + final UInt256 arg1Int = UInt256.fromBytes(arg1); + final UInt256 arg2Int = UInt256.fromBytes(arg2); + final BigInteger arg1BigInt = arg1Int.toUnsignedBigInteger(); + final BigInteger arg2BigInt = arg2Int.toUnsignedBigInteger(); + + this.tinyBase = isTiny(arg1BigInt); + this.tinyExponent = isTiny(arg2BigInt); + + final Regime regime = getRegime(opCode); + System.out.println(regime); + switch (regime) { + case TRIVIAL_MUL: + break; + case NON_TRIVIAL_MUL: + cBytes = new BytesBaseTheta(res); + break; + case EXPONENT_ZERO_RESULT: + setArraysForZeroResultCase(); + break; + case EXPONENT_NON_ZERO_RESULT: + setExponentBit(); + snm = false; + break; + case IOTA: + throw new RuntimeException("alu/mul regime was never set"); + } + } + + private void setArraysForZeroResultCase() { + // TODO + } + + private boolean setExponentBit() { + // TODO + return false; + // return string(exponentBits[md.index]) == "1"; + } + + private enum Regime { + IOTA, + TRIVIAL_MUL, + NON_TRIVIAL_MUL, + EXPONENT_ZERO_RESULT, + EXPONENT_NON_ZERO_RESULT + } + + public boolean isOneLineInstruction() { + return tinyBase || tinyExponent; + } + + private Regime getRegime( + final OpCode opCode) { + + if (isOneLineInstruction()) return Regime.TRIVIAL_MUL; + + if (OpCode.MUL.equals(opCode)) { + return Regime.NON_TRIVIAL_MUL; + } + + if (OpCode.EXP.equals(opCode)) { + if (res.isZero()) { + return Regime.EXPONENT_ZERO_RESULT; + } else { + return Regime.EXPONENT_NON_ZERO_RESULT; + } + } + return Regime.IOTA; + } + public static boolean isTiny(BigInteger arg) { + return arg.compareTo(BigInteger.valueOf(1)) <= 0; + } +} diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java index 4d2af4526e..be9f92b3ce 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java @@ -21,7 +21,6 @@ import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.bytes.Bytes16; -import net.consensys.linea.zktracer.bytes.BytesBaseTheta; import net.consensys.linea.zktracer.bytes.UnsignedByte; import net.consensys.linea.zktracer.module.ModuleTracer; import org.apache.tuweni.bytes.Bytes; @@ -54,45 +53,11 @@ public Object trace(MessageFrame frame) { final Bytes16 arg2Hi = Bytes16.wrap(arg2.slice(0, 16)); final Bytes16 arg2Lo = Bytes16.wrap(arg2.slice(16)); - final UInt256 arg1Int = UInt256.fromBytes(arg1); - final UInt256 arg2Int = UInt256.fromBytes(arg2); - final BigInteger arg1BigInt = arg1Int.toUnsignedBigInteger(); - final BigInteger arg2BigInt = arg2Int.toUnsignedBigInteger(); - final OpCode opCode = OpCode.of(frame.getCurrentOperation().getOpcode()); - final boolean tinyBase = isTiny(arg1BigInt); - final boolean tinyExponent = isTiny(arg2BigInt); - - final boolean isOneLineInstruction = isOneLineInstruction(tinyBase, tinyExponent); - final Res res = Res.create(opCode, arg1, arg2); - + final MulData data = new MulData(opCode, arg1, arg2); final MulTrace.Trace.Builder builder = MulTrace.Trace.Builder.newInstance(); - - final BytesBaseTheta aBytes = new BytesBaseTheta(arg1); - final BytesBaseTheta bBytes = new BytesBaseTheta(arg2); - BytesBaseTheta cBytes = null; - BytesBaseTheta hBytes; - boolean snm = false; - - final Regime regime = getRegime(opCode, tinyBase, tinyExponent, res); - System.out.println(regime); - switch (regime) { - case TRIVIAL_MUL: - break; - case NON_TRIVIAL_MUL: - cBytes = new BytesBaseTheta(res); - break; - case EXPONENT_ZERO_RESULT: - setArraysForZeroResultCase(); - break; - case EXPONENT_NON_ZERO_RESULT: - setExponentBit(); - snm = false; - break; - case IOTA: - throw new RuntimeException("alu/mul regime was never set"); - } + final boolean isOneLineInstruction = data.isOneLineInstruction(); stamp++; for (int i = 0; i < maxCt(isOneLineInstruction); i++) { @@ -101,9 +66,9 @@ public Object trace(MessageFrame frame) { builder .appendOneLineInstruction(isOneLineInstruction) - .appendTinyBase(tinyBase) - .appendTinyExponent(tinyExponent) - .appendResultVanishes(res.isZero()); + .appendTinyBase(data.tinyBase) + .appendTinyExponent(data.tinyExponent) + .appendResultVanishes(data.res.isZero()); builder .appendInst(UnsignedByte.of(opCode.value)) @@ -113,102 +78,45 @@ public Object trace(MessageFrame frame) { .appendArg2Lo(arg2Lo.toUnsignedBigInteger()); builder - .appendResHi(res.getResHi().toUnsignedBigInteger()) - .appendResLo(res.getResLo().toUnsignedBigInteger()); + .appendResHi(data.res.getResHi().toUnsignedBigInteger()) + .appendResLo(data.res.getResLo().toUnsignedBigInteger()); // builder.appendBits(bits.get(i)).appendCounter(i); // TODO builder - .appendByteA3(UnsignedByte.of(aBytes.get(3, i))) - .appendByteA2(UnsignedByte.of(aBytes.get(2, i))) - .appendByteA1(UnsignedByte.of(aBytes.get(1, i))) - .appendByteA0(UnsignedByte.of(aBytes.get(0, i))); + .appendByteA3(UnsignedByte.of(data.aBytes.get(3, i))) + .appendByteA2(UnsignedByte.of(data.aBytes.get(2, i))) + .appendByteA1(UnsignedByte.of(data.aBytes.get(1, i))) + .appendByteA0(UnsignedByte.of(data.aBytes.get(0, i))); builder - .appendAccA3(Bytes.of(aBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccA2(Bytes.of(aBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccA1(Bytes.of(aBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccA0(Bytes.of(aBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + .appendAccA3(Bytes.of(data.aBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) + .appendAccA2(Bytes.of(data.aBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) + .appendAccA1(Bytes.of(data.aBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) + .appendAccA0(Bytes.of(data.aBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); builder - .appendByteB3(UnsignedByte.of(bBytes.get(3, i))) - .appendByteB2(UnsignedByte.of(bBytes.get(2, i))) - .appendByteB1(UnsignedByte.of(bBytes.get(1, i))) - .appendByteB0(UnsignedByte.of(bBytes.get(0, i))); + .appendByteB3(UnsignedByte.of(data.bBytes.get(3, i))) + .appendByteB2(UnsignedByte.of(data.bBytes.get(2, i))) + .appendByteB1(UnsignedByte.of(data.bBytes.get(1, i))) + .appendByteB0(UnsignedByte.of(data.bBytes.get(0, i))); builder - .appendAccB3(Bytes.of(bBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB2(Bytes.of(bBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB1(Bytes.of(bBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB0(Bytes.of(bBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + .appendAccB3(Bytes.of(data.bBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB2(Bytes.of(data.bBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB1(Bytes.of(data.bBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB0(Bytes.of(data.bBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); builder - .appendByteC3(UnsignedByte.of(cBytes.get(3, i))) - .appendByteC2(UnsignedByte.of(cBytes.get(2, i))) - .appendByteC1(UnsignedByte.of(cBytes.get(1, i))) - .appendByteC0(UnsignedByte.of(cBytes.get(0, i))); + .appendByteC3(UnsignedByte.of(data.cBytes.get(3, i))) + .appendByteC2(UnsignedByte.of(data.cBytes.get(2, i))) + .appendByteC1(UnsignedByte.of(data.cBytes.get(1, i))) + .appendByteC0(UnsignedByte.of(data.cBytes.get(0, i))); } builder.setStamp(stamp); return builder.build(); } - private void setArraysForZeroResultCase() { - // TODO - } - - private boolean setExponentBit() { - // TODO - return false; - // return string(exponentBits[md.index]) == "1"; - } - - public static boolean isTiny(BigInteger arg) { - return arg.compareTo(BigInteger.valueOf(1)) <= 0; - } private int maxCt(final boolean isOneLineInstruction) { return isOneLineInstruction ? 1 : MMEDIUM; } - - private boolean isOneLineInstruction(final boolean tinyBase, final boolean tinyExponent) { - return tinyBase || tinyExponent; - } - - private enum Regime { - IOTA, - TRIVIAL_MUL, - NON_TRIVIAL_MUL, - EXPONENT_ZERO_RESULT, - EXPONENT_NON_ZERO_RESULT - } - - private Regime getRegime( - final OpCode opCode, final boolean tinyBase, final boolean tinyExponent, final Res res) { - - if (isOneLineInstruction(tinyBase, tinyExponent)) return Regime.TRIVIAL_MUL; - - if (OpCode.MUL.equals(opCode)) { - return Regime.NON_TRIVIAL_MUL; - } - - if (OpCode.EXP.equals(opCode)) { - if (res.isZero()) { - return Regime.EXPONENT_ZERO_RESULT; - } else { - return Regime.EXPONENT_NON_ZERO_RESULT; - } - } - return Regime.IOTA; - } - - public int lineCount(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { - - final UInt256 arg1Int = UInt256.fromBytes(arg1); - final UInt256 arg2Int = UInt256.fromBytes(arg2); - final BigInteger arg1BigInt = arg1Int.toUnsignedBigInteger(); - final BigInteger arg2BigInt = arg2Int.toUnsignedBigInteger(); - - final boolean tinyBase = isTiny(arg1BigInt); - final boolean tinyExponent = isTiny(arg2BigInt); - - return maxCt(isOneLineInstruction(tinyBase, tinyExponent)); - } } diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java index b69f177de9..e5f4ebb985 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java @@ -9,9 +9,9 @@ public class MulUtilsTest { @Test public void isTiny() { - assertThat(MulTracer.isTiny(BigInteger.ZERO)).isTrue(); - assertThat(MulTracer.isTiny(BigInteger.ONE)).isTrue(); - assertThat(MulTracer.isTiny(BigInteger.TWO)).isFalse(); - assertThat(MulTracer.isTiny(BigInteger.TEN)).isFalse(); + assertThat(MulData.isTiny(BigInteger.ZERO)).isTrue(); + assertThat(MulData.isTiny(BigInteger.ONE)).isTrue(); + assertThat(MulData.isTiny(BigInteger.TWO)).isFalse(); + assertThat(MulData.isTiny(BigInteger.TEN)).isFalse(); } } From 3e77c4d3a671ce7732b3ba7fd9b339ef713ce6c1 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 12 Apr 2023 12:37:25 +1000 Subject: [PATCH 09/31] wip Signed-off-by: Sally MacFarlane --- .../zktracer/module/alu/mul/MulData.java | 43 ++++++++++++++++--- .../zktracer/module/alu/mul/MulTrace.java | 27 ++++++++++++ .../zktracer/module/alu/mul/MulTracer.java | 23 +++++++++- 3 files changed, 86 insertions(+), 7 deletions(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index b439c63176..77d8b47299 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -13,20 +13,31 @@ public class MulData { final boolean tinyBase; final boolean tinyExponent; + final UInt256 resAcc; // accumulator which converges in a series of "square and multiply"'s + final UInt256 expAcc; // accumulator for the doubles and adds of the exponent, resets at some point + final BytesBaseTheta aBytes; final BytesBaseTheta bBytes; BytesBaseTheta cBytes; BytesBaseTheta hBytes; boolean snm = false; + int index; + boolean[] bits; + String exponentBits; + Res res; public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { this.opCode = opCode; this.aBytes = new BytesBaseTheta(arg1); this.bBytes = new BytesBaseTheta(arg2); + + // TODO what should these be initialized to this.cBytes = null; this.hBytes = null; boolean snm = false; + this.resAcc = UInt256.MIN_VALUE; + this.expAcc = UInt256.MIN_VALUE; this.res = Res.create(opCode, arg1, arg2); // TODO can we get this from the EVM @@ -51,7 +62,7 @@ public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { setArraysForZeroResultCase(); break; case EXPONENT_NON_ZERO_RESULT: - setExponentBit(); + this.exponentBits = arg2.toBigInteger().toString(); snm = false; break; case IOTA: @@ -63,10 +74,16 @@ private void setArraysForZeroResultCase() { // TODO } - private boolean setExponentBit() { - // TODO - return false; - // return string(exponentBits[md.index]) == "1"; + public boolean exponentBit() { + return '1' == exponentBits.charAt(index); + } + + public boolean exponentSource() { + return this.index + 128 >= exponentBits.length(); + } + + private boolean largeExponent() { + return exponentBits.length() > 128; } private enum Regime { @@ -102,4 +119,20 @@ private Regime getRegime( public static boolean isTiny(BigInteger arg) { return arg.compareTo(BigInteger.valueOf(1)) <= 0; } + + public int getBitNum() { + return bitNum(index, exponentBits.length()); + } + + private int bitNum(int i, int length) { + if (length <= 128) { + return i; + } else { + if (i+128 < length) { + return i; + } else { + return i + 128 - length; + } + } + } } diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java index 5301846f81..1bc5ceeb97 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java @@ -423,6 +423,33 @@ public Builder appendResLo(final BigInteger b) { return this; } + // + + public Builder appendExponentBit(final Boolean b) { + exponentBit.add(b); + return this; + } + + public Builder appendExponentBitAcc(final UnsignedByte b) { + exponentBitAccumulator.add(b); + return this; + } + + public Builder appendExponentBitSource(final Boolean b) { + exponentBitSource.add(b); + return this; + } + + public Builder appendSquareAndMultiply(final Boolean b) { + squareAndMultiply.add(b); + return this; + } + + public Builder appendExponentBitNum(final Integer b) { + bitNum.add(b); + return this; + } + public Builder appendStamp(final Integer b) { mulStamp.add(b); return this; diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java index be9f92b3ce..d043c0f0e8 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java @@ -16,7 +16,6 @@ import org.hyperledger.besu.evm.frame.MessageFrame; -import java.math.BigInteger; import java.util.List; import net.consensys.linea.zktracer.OpCode; @@ -25,7 +24,6 @@ import net.consensys.linea.zktracer.module.ModuleTracer; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; -import org.apache.tuweni.units.bigints.UInt256; public class MulTracer implements ModuleTracer { private static final int MMEDIUM = 8; @@ -109,6 +107,27 @@ public Object trace(MessageFrame frame) { .appendByteC2(UnsignedByte.of(data.cBytes.get(2, i))) .appendByteC1(UnsignedByte.of(data.cBytes.get(1, i))) .appendByteC0(UnsignedByte.of(data.cBytes.get(0, i))); + builder + .appendAccB3(Bytes.of(data.cBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB2(Bytes.of(data.cBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB1(Bytes.of(data.cBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB0(Bytes.of(data.cBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + + builder + .appendByteH3(UnsignedByte.of(data.hBytes.get(3, i))) + .appendByteH2(UnsignedByte.of(data.hBytes.get(2, i))) + .appendByteH1(UnsignedByte.of(data.hBytes.get(1, i))) + .appendByteH0(UnsignedByte.of(data.hBytes.get(0, i))); + builder + .appendAccB3(Bytes.of(data.hBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB2(Bytes.of(data.hBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB1(Bytes.of(data.hBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB0(Bytes.of(data.hBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + builder.appendExponentBit(data.exponentBit()) + .appendExponentBitAcc(data.expAcc) + .appendExponentBitSource(data.exponentSource()) + .appendSquareAndMultiply(data.snm) + .appendBitNum(data.getBitNum()); } builder.setStamp(stamp); From 8bac499c51971f5f92c8dcc92dda577725159df8 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 12 Apr 2023 13:57:28 +1000 Subject: [PATCH 10/31] wip Signed-off-by: Sally MacFarlane --- .../zktracer/module/alu/mul/MulData.java | 227 +++++++++--------- .../zktracer/module/alu/mul/MulTrace.java | 7 +- .../zktracer/module/alu/mul/MulTracer.java | 36 +-- 3 files changed, 135 insertions(+), 135 deletions(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index 77d8b47299..6cb07e5653 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -1,138 +1,139 @@ package net.consensys.linea.zktracer.module.alu.mul; +import java.math.BigInteger; + import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.bytes.BytesBaseTheta; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; -import java.math.BigInteger; - @SuppressWarnings("UnusedVariable") public class MulData { - final OpCode opCode; - final boolean tinyBase; - final boolean tinyExponent; - - final UInt256 resAcc; // accumulator which converges in a series of "square and multiply"'s - final UInt256 expAcc; // accumulator for the doubles and adds of the exponent, resets at some point - - final BytesBaseTheta aBytes; - final BytesBaseTheta bBytes; - BytesBaseTheta cBytes; - BytesBaseTheta hBytes; + final OpCode opCode; + final boolean tinyBase; + final boolean tinyExponent; + + final UInt256 resAcc; // accumulator which converges in a series of "square and multiply"'s + final UInt256 + expAcc; // accumulator for the doubles and adds of the exponent, resets at some point + + final BytesBaseTheta aBytes; + final BytesBaseTheta bBytes; + BytesBaseTheta cBytes; + BytesBaseTheta hBytes; + boolean snm = false; + int index; + boolean[] bits; + String exponentBits; + + Res res; + + public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { + + this.opCode = opCode; + this.aBytes = new BytesBaseTheta(arg1); + this.bBytes = new BytesBaseTheta(arg2); + + // TODO what should these be initialized to + this.cBytes = null; + this.hBytes = null; boolean snm = false; - int index; - boolean[] bits; - String exponentBits; - - Res res; - public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { - - this.opCode = opCode; - this.aBytes = new BytesBaseTheta(arg1); - this.bBytes = new BytesBaseTheta(arg2); - - // TODO what should these be initialized to - this.cBytes = null; - this.hBytes = null; - boolean snm = false; - this.resAcc = UInt256.MIN_VALUE; - this.expAcc = UInt256.MIN_VALUE; - - this.res = Res.create(opCode, arg1, arg2); // TODO can we get this from the EVM - - - final UInt256 arg1Int = UInt256.fromBytes(arg1); - final UInt256 arg2Int = UInt256.fromBytes(arg2); - final BigInteger arg1BigInt = arg1Int.toUnsignedBigInteger(); - final BigInteger arg2BigInt = arg2Int.toUnsignedBigInteger(); - - this.tinyBase = isTiny(arg1BigInt); - this.tinyExponent = isTiny(arg2BigInt); - - final Regime regime = getRegime(opCode); - System.out.println(regime); - switch (regime) { - case TRIVIAL_MUL: - break; - case NON_TRIVIAL_MUL: - cBytes = new BytesBaseTheta(res); - break; - case EXPONENT_ZERO_RESULT: - setArraysForZeroResultCase(); - break; - case EXPONENT_NON_ZERO_RESULT: - this.exponentBits = arg2.toBigInteger().toString(); - snm = false; - break; - case IOTA: - throw new RuntimeException("alu/mul regime was never set"); - } + this.resAcc = UInt256.MIN_VALUE; + this.expAcc = UInt256.MIN_VALUE; + + this.res = Res.create(opCode, arg1, arg2); // TODO can we get this from the EVM + + final UInt256 arg1Int = UInt256.fromBytes(arg1); + final UInt256 arg2Int = UInt256.fromBytes(arg2); + final BigInteger arg1BigInt = arg1Int.toUnsignedBigInteger(); + final BigInteger arg2BigInt = arg2Int.toUnsignedBigInteger(); + + this.tinyBase = isTiny(arg1BigInt); + this.tinyExponent = isTiny(arg2BigInt); + + final Regime regime = getRegime(opCode); + System.out.println(regime); + switch (regime) { + case TRIVIAL_MUL: + break; + case NON_TRIVIAL_MUL: + cBytes = new BytesBaseTheta(res); + break; + case EXPONENT_ZERO_RESULT: + setArraysForZeroResultCase(); + break; + case EXPONENT_NON_ZERO_RESULT: + this.exponentBits = arg2.toBigInteger().toString(); + snm = false; + break; + case IOTA: + throw new RuntimeException("alu/mul regime was never set"); } + } - private void setArraysForZeroResultCase() { - // TODO - } + private void setArraysForZeroResultCase() { + // TODO + } - public boolean exponentBit() { - return '1' == exponentBits.charAt(index); - } + public boolean exponentBit() { + return '1' == exponentBits.charAt(index); + } - public boolean exponentSource() { - return this.index + 128 >= exponentBits.length(); - } + public boolean exponentSource() { + return this.index + 128 >= exponentBits.length(); + } - private boolean largeExponent() { - return exponentBits.length() > 128; - } + private boolean largeExponent() { + return exponentBits.length() > 128; + } - private enum Regime { - IOTA, - TRIVIAL_MUL, - NON_TRIVIAL_MUL, - EXPONENT_ZERO_RESULT, - EXPONENT_NON_ZERO_RESULT - } + private enum Regime { + IOTA, + TRIVIAL_MUL, + NON_TRIVIAL_MUL, + EXPONENT_ZERO_RESULT, + EXPONENT_NON_ZERO_RESULT + } - public boolean isOneLineInstruction() { - return tinyBase || tinyExponent; - } + public boolean isOneLineInstruction() { + return tinyBase || tinyExponent; + } - private Regime getRegime( - final OpCode opCode) { + private Regime getRegime(final OpCode opCode) { - if (isOneLineInstruction()) return Regime.TRIVIAL_MUL; + if (isOneLineInstruction()) return Regime.TRIVIAL_MUL; - if (OpCode.MUL.equals(opCode)) { - return Regime.NON_TRIVIAL_MUL; - } - - if (OpCode.EXP.equals(opCode)) { - if (res.isZero()) { - return Regime.EXPONENT_ZERO_RESULT; - } else { - return Regime.EXPONENT_NON_ZERO_RESULT; - } - } - return Regime.IOTA; - } - public static boolean isTiny(BigInteger arg) { - return arg.compareTo(BigInteger.valueOf(1)) <= 0; + if (OpCode.MUL.equals(opCode)) { + return Regime.NON_TRIVIAL_MUL; } - public int getBitNum() { - return bitNum(index, exponentBits.length()); + if (OpCode.EXP.equals(opCode)) { + if (res.isZero()) { + return Regime.EXPONENT_ZERO_RESULT; + } else { + return Regime.EXPONENT_NON_ZERO_RESULT; + } } - - private int bitNum(int i, int length) { - if (length <= 128) { - return i; - } else { - if (i+128 < length) { - return i; - } else { - return i + 128 - length; - } - } + return Regime.IOTA; + } + + public static boolean isTiny(BigInteger arg) { + return arg.compareTo(BigInteger.valueOf(1)) <= 0; + } + + public int getBitNum() { + return bitNum(index, exponentBits.length()); + } + + private int bitNum(int i, int length) { + if (length <= 128) { + return i; + } else { + if (i + 128 < length) { + return i; + } else { + return i + 128 - length; + } } + } } diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java index 1bc5ceeb97..9677ca8487 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java @@ -185,8 +185,7 @@ public static class Builder { private final List accH0 = new ArrayList<>(); private final List exponentBit = new ArrayList<>(); - // mul.Trace.PushLoBytes(EXPONENT_BIT_ACCUMULATOR.Name(), md.expAcc) // TODO: true ? risky :D - private final List exponentBitAccumulator = new ArrayList<>(); + private final List exponentBitAccumulator = new ArrayList<>(); private final List exponentBitSource = new ArrayList<>(); private final List squareAndMultiply = new ArrayList<>(); private final List bitNum = new ArrayList<>(); @@ -430,7 +429,7 @@ public Builder appendExponentBit(final Boolean b) { return this; } - public Builder appendExponentBitAcc(final UnsignedByte b) { + public Builder appendExponentBitAcc(final BigInteger b) { exponentBitAccumulator.add(b); return this; } @@ -445,7 +444,7 @@ public Builder appendSquareAndMultiply(final Boolean b) { return this; } - public Builder appendExponentBitNum(final Integer b) { + public Builder appendBitNum(final Integer b) { bitNum.add(b); return this; } diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java index d043c0f0e8..edba74424b 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java @@ -108,33 +108,33 @@ public Object trace(MessageFrame frame) { .appendByteC1(UnsignedByte.of(data.cBytes.get(1, i))) .appendByteC0(UnsignedByte.of(data.cBytes.get(0, i))); builder - .appendAccB3(Bytes.of(data.cBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB2(Bytes.of(data.cBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB1(Bytes.of(data.cBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB0(Bytes.of(data.cBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + .appendAccB3(Bytes.of(data.cBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB2(Bytes.of(data.cBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB1(Bytes.of(data.cBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB0(Bytes.of(data.cBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); builder - .appendByteH3(UnsignedByte.of(data.hBytes.get(3, i))) - .appendByteH2(UnsignedByte.of(data.hBytes.get(2, i))) - .appendByteH1(UnsignedByte.of(data.hBytes.get(1, i))) - .appendByteH0(UnsignedByte.of(data.hBytes.get(0, i))); + .appendByteH3(UnsignedByte.of(data.hBytes.get(3, i))) + .appendByteH2(UnsignedByte.of(data.hBytes.get(2, i))) + .appendByteH1(UnsignedByte.of(data.hBytes.get(1, i))) + .appendByteH0(UnsignedByte.of(data.hBytes.get(0, i))); builder - .appendAccB3(Bytes.of(data.hBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB2(Bytes.of(data.hBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB1(Bytes.of(data.hBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB0(Bytes.of(data.hBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); - builder.appendExponentBit(data.exponentBit()) - .appendExponentBitAcc(data.expAcc) - .appendExponentBitSource(data.exponentSource()) - .appendSquareAndMultiply(data.snm) - .appendBitNum(data.getBitNum()); + .appendAccB3(Bytes.of(data.hBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB2(Bytes.of(data.hBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB1(Bytes.of(data.hBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB0(Bytes.of(data.hBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + builder + .appendExponentBit(data.exponentBit()) + .appendExponentBitAcc(data.expAcc.toUnsignedBigInteger()) + .appendExponentBitSource(data.exponentSource()) + .appendSquareAndMultiply(data.snm) + .appendBitNum(data.getBitNum()); } builder.setStamp(stamp); return builder.build(); } - private int maxCt(final boolean isOneLineInstruction) { return isOneLineInstruction ? 1 : MMEDIUM; } From b5834d5d47693884ff05931f32d41b305eb2246e Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 12 Apr 2023 13:57:43 +1000 Subject: [PATCH 11/31] twoAdicity Signed-off-by: Sally MacFarlane --- .../zktracer/module/alu/mul/MulData.java | 24 +++++++++++++++++-- .../zktracer/module/alu/mul/MulTrace.java | 2 +- .../zktracer/module/alu/mul/MulUtilsTest.java | 9 +++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index 6cb07e5653..39732ab131 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -83,9 +83,29 @@ public boolean exponentSource() { return this.index + 128 >= exponentBits.length(); } - private boolean largeExponent() { - return exponentBits.length() > 128; + public static int twoAdicity(final UInt256 x) { + + if (x.isZero()) { + // panic("twoAdicity was called on zero") + return 256; + } + + String baseStringBase2 = x.toBigInteger().toString(2); + + for (int i = 0; i < baseStringBase2.length(); i++) { + int j = baseStringBase2.length() - i - 1; + char zeroAscii = '0'; + if (baseStringBase2.charAt(j) != zeroAscii) { + return i; + } + } + + return 0; } + // + // private boolean largeExponent() { + // return exponentBits.length() > 128; + // } private enum Regime { IOTA, diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java index 9677ca8487..c13fa5a91e 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java @@ -120,7 +120,7 @@ public record Trace( @JsonProperty("BYTE_H_3") List BYTE_H_3, @JsonProperty("COUNTER") List COUNTER, @JsonProperty("EXPONENT_BIT") List EXPONENT_BIT, - @JsonProperty("EXPONENT_BIT_ACCUMULATOR") List EXPONENT_BIT_ACCUMULATOR, + @JsonProperty("EXPONENT_BIT_ACCUMULATOR") List EXPONENT_BIT_ACCUMULATOR, @JsonProperty("EXPONENT_BIT_SOURCE") List EXPONENT_BIT_SOURCE, @JsonProperty("INST") List INST, @JsonProperty("MUL_STAMP") List MUL_STAMP, diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java index e5f4ebb985..379b51afdf 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java @@ -4,6 +4,7 @@ import java.math.BigInteger; +import org.apache.tuweni.units.bigints.UInt256; import org.junit.jupiter.api.Test; public class MulUtilsTest { @@ -14,4 +15,12 @@ public void isTiny() { assertThat(MulData.isTiny(BigInteger.TWO)).isFalse(); assertThat(MulData.isTiny(BigInteger.TEN)).isFalse(); } + + @Test + public void twoAdicity() { + assertThat(MulData.twoAdicity(UInt256.MIN_VALUE)).isEqualTo(256); + // TODO no idea what these should be + // assertThat(MulData.twoAdicity(UInt256.MAX_VALUE)).isEqualTo(0); + // assertThat(MulData.twoAdicity(UInt256.valueOf(1))).isEqualTo(0); + } } From 1888d85b2515467476f3a597f265acd044265785 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 12 Apr 2023 14:08:53 +1000 Subject: [PATCH 12/31] fixed getting BytesTheta from Res hi/lo Signed-off-by: Sally MacFarlane --- .../java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java b/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java index b905d4b4a6..c2f3a3566c 100644 --- a/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java +++ b/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java @@ -29,7 +29,7 @@ public BytesBaseTheta(final Res res) { System.arraycopy(argBytesHi, 8 * k, bytes[3 - k], 0, 8); } for (int k = 2; k < 4; k++) { - System.arraycopy(argBytesLo, 8 * k, bytes[3 - k], 0, 8); + System.arraycopy(argBytesLo, 8 * (k-2), bytes[3 - k], 0, 8); } } From 48fac12c2e7014df3385113690b7b36d24191fa2 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 12 Apr 2023 15:33:49 +1000 Subject: [PATCH 13/31] working on setHsAndBits Signed-off-by: Sally MacFarlane --- .../linea/zktracer/bytes/BytesBaseTheta.java | 22 +++ .../zktracer/module/alu/mul/MulData.java | 129 ++++++++++++++++-- 2 files changed, 141 insertions(+), 10 deletions(-) diff --git a/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java b/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java index c2f3a3566c..20802e39f2 100644 --- a/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java +++ b/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java @@ -1,5 +1,6 @@ package net.consensys.linea.zktracer.bytes; +import java.math.BigInteger; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; @@ -33,6 +34,23 @@ public BytesBaseTheta(final Res res) { } } + public void set(final BigInteger bigInteger) { + // TODO how to get from BigInteger to bytes + bigInteger.toByteArray(); + + } + public void set(final int i, final BigInteger bigInteger) { + // TODO handle underflow + byte[] bigIntByteArray = bigInteger.toByteArray(); + System.arraycopy(bigIntByteArray, 0, bytes[i], 0, 8); + + } + public void set(final int i, final byte[] chunk) { + // TODO handle underflow + System.arraycopy(chunk, 0, bytes[i], 0, 8); + + } + // TODO can Res become Pair as below public Pair getHiLo() { byte[] hiBytes = new byte[16]; @@ -47,6 +65,10 @@ public Pair getHiLo() { return new Pair<>(hiBytes, loBytes); } + public byte[] getChunk(final int i) { + return bytes[i]; + } + public byte get(final int i, final int j) { return bytes[i][j]; } diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index 39732ab131..e72053efc9 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -1,21 +1,24 @@ package net.consensys.linea.zktracer.module.alu.mul; +import java.lang.reflect.Array; import java.math.BigInteger; import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.bytes.BytesBaseTheta; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; +import org.apache.tuweni.units.bigints.UInt64; @SuppressWarnings("UnusedVariable") public class MulData { final OpCode opCode; + final Bytes32 arg1; + final Bytes32 arg2; final boolean tinyBase; final boolean tinyExponent; - final UInt256 resAcc; // accumulator which converges in a series of "square and multiply"'s - final UInt256 - expAcc; // accumulator for the doubles and adds of the exponent, resets at some point + BigInteger resAcc; // accumulator which converges in a series of "square and multiply"'s + UInt256 expAcc; // accumulator for doubles and adds of the exponent, resets at some point final BytesBaseTheta aBytes; final BytesBaseTheta bBytes; @@ -31,22 +34,20 @@ public class MulData { public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { this.opCode = opCode; + this.arg1 = arg1; + this.arg2 = arg2; this.aBytes = new BytesBaseTheta(arg1); this.bBytes = new BytesBaseTheta(arg2); - // TODO what should these be initialized to + // TODO what should these be initialized to (or is this not needed) this.cBytes = null; this.hBytes = null; - boolean snm = false; - this.resAcc = UInt256.MIN_VALUE; this.expAcc = UInt256.MIN_VALUE; this.res = Res.create(opCode, arg1, arg2); // TODO can we get this from the EVM - final UInt256 arg1Int = UInt256.fromBytes(arg1); - final UInt256 arg2Int = UInt256.fromBytes(arg2); - final BigInteger arg1BigInt = arg1Int.toUnsignedBigInteger(); - final BigInteger arg2BigInt = arg2Int.toUnsignedBigInteger(); + final BigInteger arg1BigInt = UInt256.fromBytes(arg1).toUnsignedBigInteger(); + final BigInteger arg2BigInt = UInt256.fromBytes(arg2).toUnsignedBigInteger(); this.tinyBase = isTiny(arg1BigInt); this.tinyExponent = isTiny(arg2BigInt); @@ -156,4 +157,112 @@ private int bitNum(int i, int length) { } } } + + private void update() { + + final BigInteger arg1BigInt = UInt256.fromBytes(arg1).toUnsignedBigInteger(); + final BigInteger arg2BigInt = UInt256.fromBytes(arg2).toUnsignedBigInteger(); + if (!snm) { + // squaring + setHsAndBits(resAcc, resAcc); + expAcc = expAcc.add(expAcc); + resAcc = resAcc.multiply(resAcc); + } else { + // multiplying by base + setHsAndBits(arg1BigInt, resAcc); + expAcc = expAcc.add(UInt256.ONE); + resAcc = arg1BigInt.multiply(resAcc); + } + cBytes.set(resAcc); // TODO how to get from BigInteger to Bytes32 + } + + private void setHsAndBits(BigInteger a, BigInteger b) { + + // TODO set hBytes and bits[] + BytesBaseTheta aBaseTheta, bBaseTheta, sumBaseTheta ; + + aBaseTheta.set(a); + bBaseTheta.set(b); + + UInt256[] aBaseThetaInts = (UInt256[]) Array.newInstance(UInt256.class, 4); + UInt256[] bBaseThetaInts = (UInt256[]) Array.newInstance(UInt256.class, 4); + + for (int i = 0; i < 4; i++) { + aBaseThetaInts[i] = UInt256.ZERO; + bBaseThetaInts[i] = UInt256.ZERO; + aBaseThetaInts[i].setBytes(aBaseTheta.getChunk(i)); + bBaseThetaInts[i].setBytes(bBaseTheta.getChunk(i)); + } + + UInt256 sum, prod; + prod = aBaseThetaInts[1].multiply(bBaseThetaInts[0]); + sum = UInt256.MIN_VALUE.add(prod); // sum := a1 * b0 + prod = aBaseThetaInts[0].multiply(bBaseThetaInts[1]); + sum = sum.add(prod); // sum += a0 * b1 + + sumBaseTheta.set(sum.toBigInteger()); + hBytes.set(0, sumBaseTheta.getChunk(0)); + hBytes.set(1, sumBaseTheta.getChunk(1)); + int alpha = getOverflow(sum, 1, "alpha OOB"); + + prod = aBaseThetaInts[3].multiply(bBaseThetaInts[0]); + sum = UInt256.MIN_VALUE.add(prod); // sum := a3 * b0 + prod = aBaseThetaInts[2].multiply(bBaseThetaInts[1]); + sum = sum.add(prod); // sum += a2 * b1 + prod = aBaseThetaInts[1].multiply(bBaseThetaInts[2]); + sum = sum.add(prod); // sum += a1 * b2 + prod = aBaseThetaInts[0].multiply(bBaseThetaInts[3]); + sum = sum.add(prod); // sum += a0 * b3 + + sumBaseTheta.set(sum.toBigInteger()); + hBytes.set(2, sumBaseTheta.getChunk(0)); + hBytes.set(3, sumBaseTheta.getChunk(1)); + int beta = getOverflow(sum, 3, "beta OOB"); + + prod = aBaseThetaInts[0].multiply(bBaseThetaInts[0]); + sum = UInt256.MIN_VALUE.add(prod); // sum := a0 * b0 + prod = hBytes.getChunk(0).shiftLeft(64); + sum = sum.add(prod);// sum += (h0 << 64) +// sum.Add(sum, prod.Lsh(prod.SetBytes(hs[0][:]), 64)) // sum += (h0 << 64) + + int eta = getOverflow(sum, 1, "eta OOB"); + + sum = UInt256.valueOf(eta); // sum := eta + sum.Add(sum, prod.SetBytes(hs[1][:])) // sum += h1 + sum.Add(sum, prod.Lsh(prod.SetUint64(alpha), 64)) ; // sum += (alpha << 64) + prod = aBaseThetaInts[2].multiply(bBaseThetaInts[0]); + sum = sum.add(prod); // sum += a2 * b0 + prod = aBaseThetaInts[1].multiply(bBaseThetaInts[1]); + sum = sum.add(prod); // sum += a1 * b1 + prod = aBaseThetaInts[0].multiply(bBaseThetaInts[2]); + sum = sum.add(prod); // sum += a0 * b2 + sum.Add(sum, prod.Lsh(prod.SetBytes(hs[2][:]), 64)) // sum += (h2 << 64) + + int mu = getOverflow(sum, 3, "mu OOB"); + + bits[2] = getBit(alpha, 0); + bits[3] = getBit(beta, 0); + bits[4] = getBit(beta, 1); + bits[5] = getBit(eta, 0); + bits[6] = getBit(mu, 0); + bits[7] = getBit(mu, 1); + + return; + } + + public static int getOverflow(final UInt256 arg, final int maxVal, final String err) { + UInt256 shiftRight = arg.shiftRight( 128); + if (shiftRight.toBigInteger().compareTo (UInt64.MAX_VALUE.toBigInteger()) > 0) { + throw new RuntimeException("getOverflow expects a small high part"); + } + int overflow = shiftRight.toInt(); + if (overflow > maxVal) { + throw new RuntimeException(err); + } + return overflow; + } + // GetBit returns true iff the k'th bit of x is 1 + private boolean getBit(int x, int k) { + return (x>>k)%2 == 1; + } } From f17a19d83ba53f85eefa6a2bf2d4659dbae1f699 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Thu, 13 Apr 2023 14:17:23 +1000 Subject: [PATCH 14/31] WIP Signed-off-by: Sally MacFarlane --- .../linea/zktracer/bytes/Bytes16.java | 2 +- .../linea/zktracer/bytes/BytesBaseTheta.java | 11 +- .../zktracer/module/alu/mul/MulData.java | 173 +++++++++++++---- .../zktracer/module/alu/mul/MulTracer.java | 183 ++++++++++-------- 4 files changed, 244 insertions(+), 125 deletions(-) diff --git a/src/main/java/net/consensys/linea/zktracer/bytes/Bytes16.java b/src/main/java/net/consensys/linea/zktracer/bytes/Bytes16.java index 3a4f7cfcad..9ff1ffc193 100644 --- a/src/main/java/net/consensys/linea/zktracer/bytes/Bytes16.java +++ b/src/main/java/net/consensys/linea/zktracer/bytes/Bytes16.java @@ -132,7 +132,7 @@ static Bytes16 leftPad(Bytes value) { * Right pad a {@link Bytes} value with zero bytes to create a {@link Bytes16}. * * @param value The bytes value pad. - * @return A {@link Bytes16} that exposes the rightw-padded bytes of {@code value}. + * @return A {@link Bytes16} that exposes the right-padded bytes of {@code value}. * @throws IllegalArgumentException if {@code value.size() > 16}. */ static Bytes16 rightPad(Bytes value) { diff --git a/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java b/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java index 20802e39f2..32ee90a905 100644 --- a/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java +++ b/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java @@ -30,25 +30,24 @@ public BytesBaseTheta(final Res res) { System.arraycopy(argBytesHi, 8 * k, bytes[3 - k], 0, 8); } for (int k = 2; k < 4; k++) { - System.arraycopy(argBytesLo, 8 * (k-2), bytes[3 - k], 0, 8); + System.arraycopy(argBytesLo, 8 * (k - 2), bytes[3 - k], 0, 8); } } public void set(final BigInteger bigInteger) { // TODO how to get from BigInteger to bytes bigInteger.toByteArray(); - } + public void set(final int i, final BigInteger bigInteger) { // TODO handle underflow byte[] bigIntByteArray = bigInteger.toByteArray(); System.arraycopy(bigIntByteArray, 0, bytes[i], 0, 8); - } + public void set(final int i, final byte[] chunk) { // TODO handle underflow System.arraycopy(chunk, 0, bytes[i], 0, 8); - } // TODO can Res become Pair as below @@ -76,6 +75,10 @@ public byte get(final int i, final int j) { public byte[] getRange(final int i, final int start, final int end) { return Arrays.copyOfRange(bytes[i], start, end); } + + public void set(int i, int j, byte b) { + bytes[i][j] = b; + } } @SuppressWarnings("UnusedVariable") diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index e72053efc9..f3fcb700f2 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -4,16 +4,25 @@ import java.math.BigInteger; import net.consensys.linea.zktracer.OpCode; +import net.consensys.linea.zktracer.bytes.Bytes16; import net.consensys.linea.zktracer.bytes.BytesBaseTheta; +import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; import org.apache.tuweni.units.bigints.UInt64; @SuppressWarnings("UnusedVariable") public class MulData { + private static final int MMEDIUM = 8; final OpCode opCode; final Bytes32 arg1; final Bytes32 arg2; + + final Bytes16 arg1Hi; + final Bytes16 arg1Lo; + final Bytes16 arg2Hi; + final Bytes16 arg2Lo; + final boolean tinyBase; final boolean tinyExponent; @@ -39,6 +48,11 @@ public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { this.aBytes = new BytesBaseTheta(arg1); this.bBytes = new BytesBaseTheta(arg2); + arg1Hi = Bytes16.wrap(arg1.slice(0, 16)); + arg1Lo = Bytes16.wrap(arg1.slice(16)); + arg2Hi = Bytes16.wrap(arg2.slice(0, 16)); + arg2Lo = Bytes16.wrap(arg2.slice(16)); + // TODO what should these be initialized to (or is this not needed) this.cBytes = null; this.hBytes = null; @@ -52,7 +66,7 @@ public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { this.tinyBase = isTiny(arg1BigInt); this.tinyExponent = isTiny(arg2BigInt); - final Regime regime = getRegime(opCode); + final Regime regime = getRegime(); System.out.println(regime); switch (regime) { case TRIVIAL_MUL: @@ -73,7 +87,49 @@ public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { } private void setArraysForZeroResultCase() { - // TODO + int nu = twoAdicity(arg1); + + if (nu >= 128) { + return; + } + + byte[] ones = Bytes.repeat((byte) 1, 8).toArray(); + byte[] bytes; + + if (128 > nu && nu >= 64) { + bytes = aBytes.getChunk(1); + } else { + for (int i = 0; i < 8; i++) { + cBytes.set(0, ones); + } + bytes = aBytes.getChunk(0); + } + int nuQuo = (nu / 8) % 8; + int nuRem = nu % 8; + byte pivotByte = bytes[7 - nuQuo]; + + for (int i = 0; i < 8; i++) { + cBytes.set(1, i, pivotByte); + cBytes.set(2, i, boolToByte(i > 7 - nuRem)); + cBytes.set(3, i, boolToByte(i > 7 - nuQuo)); + hBytes.set(2, i, callFunc(i, 7 - nuRem)); + hBytes.set(3, i, callFunc(i, 7 - nuQuo)); + } + // TODO the rest + } + + private byte callFunc(final int x, final int k) { + if (x < k) { + return 0; + } + return (byte) (x - k); + } + + private byte boolToByte(boolean b) { + if (b) { + return 1; + } + return 0; } public boolean exponentBit() { @@ -84,7 +140,7 @@ public boolean exponentSource() { return this.index + 128 >= exponentBits.length(); } - public static int twoAdicity(final UInt256 x) { + public static int twoAdicity(final Bytes32 x) { if (x.isZero()) { // panic("twoAdicity was called on zero") @@ -108,7 +164,7 @@ public static int twoAdicity(final UInt256 x) { // return exponentBits.length() > 128; // } - private enum Regime { + public enum Regime { IOTA, TRIVIAL_MUL, NON_TRIVIAL_MUL, @@ -120,7 +176,7 @@ public boolean isOneLineInstruction() { return tinyBase || tinyExponent; } - private Regime getRegime(final OpCode opCode) { + public Regime getRegime() { if (isOneLineInstruction()) return Regime.TRIVIAL_MUL; @@ -142,6 +198,29 @@ public static boolean isTiny(BigInteger arg) { return arg.compareTo(BigInteger.valueOf(1)) <= 0; } + public boolean carryOn() { + + // first round is special + if (index == 0 && !snm) { + snm = true; + resAcc = BigInteger.ONE; // TODO assuming this is what SetOne() does + cBytes.set(arg1.toBigInteger()); + return true; + } + + if (snm == exponentBit()) { + hiToLoExponentBitAccumulatorReset(); + index++; + snm = false; + if (index == exponentBits.length()) { + return false; + } + } else { + snm = true; + } + return true; + } + public int getBitNum() { return bitNum(index, exponentBits.length()); } @@ -158,7 +237,7 @@ private int bitNum(int i, int length) { } } - private void update() { + public void update() { final BigInteger arg1BigInt = UInt256.fromBytes(arg1).toUnsignedBigInteger(); final BigInteger arg2BigInt = UInt256.fromBytes(arg2).toUnsignedBigInteger(); @@ -173,40 +252,39 @@ private void update() { expAcc = expAcc.add(UInt256.ONE); resAcc = arg1BigInt.multiply(resAcc); } - cBytes.set(resAcc); // TODO how to get from BigInteger to Bytes32 + cBytes.set(resAcc); } - private void setHsAndBits(BigInteger a, BigInteger b) { + public void setHsAndBits(BigInteger a, BigInteger b) { - // TODO set hBytes and bits[] - BytesBaseTheta aBaseTheta, bBaseTheta, sumBaseTheta ; + BytesBaseTheta aBaseTheta, bBaseTheta, sumBaseTheta; + aBaseTheta = new BytesBaseTheta(Bytes32.ZERO); + bBaseTheta = new BytesBaseTheta(Bytes32.ZERO); + sumBaseTheta = new BytesBaseTheta(Bytes32.ZERO); aBaseTheta.set(a); bBaseTheta.set(b); - UInt256[] aBaseThetaInts = (UInt256[]) Array.newInstance(UInt256.class, 4); - UInt256[] bBaseThetaInts = (UInt256[]) Array.newInstance(UInt256.class, 4); + BigInteger[] aBaseThetaInts = (BigInteger[]) Array.newInstance(UInt256.class, 4); + BigInteger[] bBaseThetaInts = (BigInteger[]) Array.newInstance(UInt256.class, 4); for (int i = 0; i < 4; i++) { - aBaseThetaInts[i] = UInt256.ZERO; - bBaseThetaInts[i] = UInt256.ZERO; - aBaseThetaInts[i].setBytes(aBaseTheta.getChunk(i)); - bBaseThetaInts[i].setBytes(bBaseTheta.getChunk(i)); + aBaseThetaInts[i] = Bytes.of(aBaseTheta.getChunk(i)).toBigInteger(); + bBaseThetaInts[i] = Bytes.of(bBaseTheta.getChunk(i)).toBigInteger(); } - UInt256 sum, prod; + BigInteger sum, prod; prod = aBaseThetaInts[1].multiply(bBaseThetaInts[0]); - sum = UInt256.MIN_VALUE.add(prod); // sum := a1 * b0 + sum = prod; // sum := a1 * b0 prod = aBaseThetaInts[0].multiply(bBaseThetaInts[1]); sum = sum.add(prod); // sum += a0 * b1 - sumBaseTheta.set(sum.toBigInteger()); + sumBaseTheta.set(sum); hBytes.set(0, sumBaseTheta.getChunk(0)); hBytes.set(1, sumBaseTheta.getChunk(1)); int alpha = getOverflow(sum, 1, "alpha OOB"); - prod = aBaseThetaInts[3].multiply(bBaseThetaInts[0]); - sum = UInt256.MIN_VALUE.add(prod); // sum := a3 * b0 + sum = aBaseThetaInts[3].multiply(bBaseThetaInts[0]); // sum := a3 * b0 prod = aBaseThetaInts[2].multiply(bBaseThetaInts[1]); sum = sum.add(prod); // sum += a2 * b1 prod = aBaseThetaInts[1].multiply(bBaseThetaInts[2]); @@ -214,29 +292,26 @@ private void setHsAndBits(BigInteger a, BigInteger b) { prod = aBaseThetaInts[0].multiply(bBaseThetaInts[3]); sum = sum.add(prod); // sum += a0 * b3 - sumBaseTheta.set(sum.toBigInteger()); + sumBaseTheta.set(sum); hBytes.set(2, sumBaseTheta.getChunk(0)); hBytes.set(3, sumBaseTheta.getChunk(1)); int beta = getOverflow(sum, 3, "beta OOB"); - prod = aBaseThetaInts[0].multiply(bBaseThetaInts[0]); - sum = UInt256.MIN_VALUE.add(prod); // sum := a0 * b0 - prod = hBytes.getChunk(0).shiftLeft(64); - sum = sum.add(prod);// sum += (h0 << 64) -// sum.Add(sum, prod.Lsh(prod.SetBytes(hs[0][:]), 64)) // sum += (h0 << 64) + sum = aBaseThetaInts[0].multiply(bBaseThetaInts[0]); // sum := a0 * b0 + sum = sum.add(shiftLeft64(hBytes.getChunk(0))); // sum += (h0 << 64) int eta = getOverflow(sum, 1, "eta OOB"); - sum = UInt256.valueOf(eta); // sum := eta - sum.Add(sum, prod.SetBytes(hs[1][:])) // sum += h1 - sum.Add(sum, prod.Lsh(prod.SetUint64(alpha), 64)) ; // sum += (alpha << 64) + sum = BigInteger.valueOf(eta); // sum := eta + sum = sum.add(Bytes16.wrap(hBytes.getChunk(1)).toUnsignedBigInteger()); // sum += h1 + sum = sum.add(BigInteger.valueOf(alpha).shiftLeft(64)); // sum += (alpha << 64) prod = aBaseThetaInts[2].multiply(bBaseThetaInts[0]); sum = sum.add(prod); // sum += a2 * b0 prod = aBaseThetaInts[1].multiply(bBaseThetaInts[1]); sum = sum.add(prod); // sum += a1 * b1 prod = aBaseThetaInts[0].multiply(bBaseThetaInts[2]); sum = sum.add(prod); // sum += a0 * b2 - sum.Add(sum, prod.Lsh(prod.SetBytes(hs[2][:]), 64)) // sum += (h2 << 64) + sum = sum.add(shiftLeft64(hBytes.getChunk(2))); // sum += (h2 << 64) int mu = getOverflow(sum, 3, "mu OOB"); @@ -250,19 +325,45 @@ private void setHsAndBits(BigInteger a, BigInteger b) { return; } - public static int getOverflow(final UInt256 arg, final int maxVal, final String err) { - UInt256 shiftRight = arg.shiftRight( 128); - if (shiftRight.toBigInteger().compareTo (UInt64.MAX_VALUE.toBigInteger()) > 0) { + private BigInteger shiftLeft64(byte[] b16) { + final Bytes16 copy = Bytes16.wrap(b16).copy(); + return copy.shiftLeft(64).toUnsignedBigInteger(); + } + + // hiToLoExponentBitAccumulatorReset resets the exponent bit accumulator + // under the following conditions: + // - we are dealing with the high part of the exponent bits, i.e. md.exponentBit() = 0 + // - SQUARE_AND_MULTIPLY == EXPONENT_BIT + // - the exponent bit accumulator coincides with the high part of the exponent + private void hiToLoExponentBitAccumulatorReset() { + if (!exponentSource()) { + if (snm == exponentBit()) { // note: when called this is already assumed + Bytes32 arg2Copy = arg2.copy(); + if (arg2Copy.shiftRight(128).equals(expAcc)) { + expAcc = UInt256.MIN_VALUE; + } + } + } + } + + public static int getOverflow(final BigInteger arg, final int maxVal, final String err) { + BigInteger shiftRight = arg.shiftRight(128); + if (shiftRight.compareTo(UInt64.MAX_VALUE.toBigInteger()) > 0) { throw new RuntimeException("getOverflow expects a small high part"); } - int overflow = shiftRight.toInt(); + int overflow = shiftRight.intValue(); if (overflow > maxVal) { throw new RuntimeException(err); } return overflow; } + // GetBit returns true iff the k'th bit of x is 1 private boolean getBit(int x, int k) { - return (x>>k)%2 == 1; + return (x >> k) % 2 == 1; + } + + public int maxCt() { + return isOneLineInstruction() ? 1 : MMEDIUM; } } diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java index edba74424b..14d25631a6 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java @@ -19,14 +19,12 @@ import java.util.List; import net.consensys.linea.zktracer.OpCode; -import net.consensys.linea.zktracer.bytes.Bytes16; import net.consensys.linea.zktracer.bytes.UnsignedByte; import net.consensys.linea.zktracer.module.ModuleTracer; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; public class MulTracer implements ModuleTracer { - private static final int MMEDIUM = 8; private int stamp = 0; @@ -46,96 +44,113 @@ public Object trace(MessageFrame frame) { final Bytes32 arg1 = Bytes32.wrap(frame.getStackItem(0)); final Bytes32 arg2 = Bytes32.wrap(frame.getStackItem(1)); - final Bytes16 arg1Hi = Bytes16.wrap(arg1.slice(0, 16)); - final Bytes16 arg1Lo = Bytes16.wrap(arg1.slice(16)); - final Bytes16 arg2Hi = Bytes16.wrap(arg2.slice(0, 16)); - final Bytes16 arg2Lo = Bytes16.wrap(arg2.slice(16)); - final OpCode opCode = OpCode.of(frame.getCurrentOperation().getOpcode()); final MulData data = new MulData(opCode, arg1, arg2); final MulTrace.Trace.Builder builder = MulTrace.Trace.Builder.newInstance(); - final boolean isOneLineInstruction = data.isOneLineInstruction(); + final int maxCt = data.maxCt(); stamp++; - for (int i = 0; i < maxCt(isOneLineInstruction); i++) { - builder.appendStamp(stamp); - builder.appendCounter(i); - - builder - .appendOneLineInstruction(isOneLineInstruction) - .appendTinyBase(data.tinyBase) - .appendTinyExponent(data.tinyExponent) - .appendResultVanishes(data.res.isZero()); - - builder - .appendInst(UnsignedByte.of(opCode.value)) - .appendArg1Hi(arg1Hi.toUnsignedBigInteger()) - .appendArg1Lo(arg1Lo.toUnsignedBigInteger()) - .appendArg2Hi(arg2Hi.toUnsignedBigInteger()) - .appendArg2Lo(arg2Lo.toUnsignedBigInteger()); - - builder - .appendResHi(data.res.getResHi().toUnsignedBigInteger()) - .appendResLo(data.res.getResLo().toUnsignedBigInteger()); - - // builder.appendBits(bits.get(i)).appendCounter(i); // TODO - - builder - .appendByteA3(UnsignedByte.of(data.aBytes.get(3, i))) - .appendByteA2(UnsignedByte.of(data.aBytes.get(2, i))) - .appendByteA1(UnsignedByte.of(data.aBytes.get(1, i))) - .appendByteA0(UnsignedByte.of(data.aBytes.get(0, i))); - builder - .appendAccA3(Bytes.of(data.aBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccA2(Bytes.of(data.aBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccA1(Bytes.of(data.aBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccA0(Bytes.of(data.aBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); - - builder - .appendByteB3(UnsignedByte.of(data.bBytes.get(3, i))) - .appendByteB2(UnsignedByte.of(data.bBytes.get(2, i))) - .appendByteB1(UnsignedByte.of(data.bBytes.get(1, i))) - .appendByteB0(UnsignedByte.of(data.bBytes.get(0, i))); - builder - .appendAccB3(Bytes.of(data.bBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB2(Bytes.of(data.bBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB1(Bytes.of(data.bBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB0(Bytes.of(data.bBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); - builder - .appendByteC3(UnsignedByte.of(data.cBytes.get(3, i))) - .appendByteC2(UnsignedByte.of(data.cBytes.get(2, i))) - .appendByteC1(UnsignedByte.of(data.cBytes.get(1, i))) - .appendByteC0(UnsignedByte.of(data.cBytes.get(0, i))); - builder - .appendAccB3(Bytes.of(data.cBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB2(Bytes.of(data.cBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB1(Bytes.of(data.cBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB0(Bytes.of(data.cBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); - - builder - .appendByteH3(UnsignedByte.of(data.hBytes.get(3, i))) - .appendByteH2(UnsignedByte.of(data.hBytes.get(2, i))) - .appendByteH1(UnsignedByte.of(data.hBytes.get(1, i))) - .appendByteH0(UnsignedByte.of(data.hBytes.get(0, i))); - builder - .appendAccB3(Bytes.of(data.hBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB2(Bytes.of(data.hBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB1(Bytes.of(data.hBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB0(Bytes.of(data.hBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); - builder - .appendExponentBit(data.exponentBit()) - .appendExponentBitAcc(data.expAcc.toUnsignedBigInteger()) - .appendExponentBitSource(data.exponentSource()) - .appendSquareAndMultiply(data.snm) - .appendBitNum(data.getBitNum()); - } - builder.setStamp(stamp); - return builder.build(); + switch (data.getRegime()) { + case EXPONENT_ZERO_RESULT: + for (int ct = 0; ct < maxCt; ct++) { + trace(builder, data, ct); + } + return builder.build(); + + case EXPONENT_NON_ZERO_RESULT: + if (data.carryOn()) { + data.update(); + for (int ct = 0; ct < maxCt; ct++) { + trace(builder, data, ct); + } + } + return builder.build(); + + case TRIVIAL_MUL, NON_TRIVIAL_MUL: + data.setHsAndBits(arg1.toBigInteger(), arg2.toBigInteger()); + for (int ct = 0; ct < maxCt; ct++) { + trace(builder, data, ct); + } + return builder.build(); + + default: + throw new RuntimeException("regime not supported"); + } } - private int maxCt(final boolean isOneLineInstruction) { - return isOneLineInstruction ? 1 : MMEDIUM; + private void trace(final MulTrace.Trace.Builder builder, final MulData data, final int i) { + builder.appendStamp(stamp); + builder.appendCounter(i); + + builder + .appendOneLineInstruction(data.isOneLineInstruction()) + .appendTinyBase(data.tinyBase) + .appendTinyExponent(data.tinyExponent) + .appendResultVanishes(data.res.isZero()); + + builder + .appendInst(UnsignedByte.of(data.opCode.value)) + .appendArg1Hi(data.arg1Hi.toUnsignedBigInteger()) + .appendArg1Lo(data.arg1Lo.toUnsignedBigInteger()) + .appendArg2Hi(data.arg2Hi.toUnsignedBigInteger()) + .appendArg2Lo(data.arg2Lo.toUnsignedBigInteger()); + + builder + .appendResHi(data.res.getResHi().toUnsignedBigInteger()) + .appendResLo(data.res.getResLo().toUnsignedBigInteger()); + + builder.appendBits(data.bits[i]); + + builder + .appendByteA3(UnsignedByte.of(data.aBytes.get(3, i))) + .appendByteA2(UnsignedByte.of(data.aBytes.get(2, i))) + .appendByteA1(UnsignedByte.of(data.aBytes.get(1, i))) + .appendByteA0(UnsignedByte.of(data.aBytes.get(0, i))); + builder + .appendAccA3(Bytes.of(data.aBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) + .appendAccA2(Bytes.of(data.aBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) + .appendAccA1(Bytes.of(data.aBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) + .appendAccA0(Bytes.of(data.aBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + + builder + .appendByteB3(UnsignedByte.of(data.bBytes.get(3, i))) + .appendByteB2(UnsignedByte.of(data.bBytes.get(2, i))) + .appendByteB1(UnsignedByte.of(data.bBytes.get(1, i))) + .appendByteB0(UnsignedByte.of(data.bBytes.get(0, i))); + builder + .appendAccB3(Bytes.of(data.bBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB2(Bytes.of(data.bBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB1(Bytes.of(data.bBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB0(Bytes.of(data.bBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + builder + .appendByteC3(UnsignedByte.of(data.cBytes.get(3, i))) + .appendByteC2(UnsignedByte.of(data.cBytes.get(2, i))) + .appendByteC1(UnsignedByte.of(data.cBytes.get(1, i))) + .appendByteC0(UnsignedByte.of(data.cBytes.get(0, i))); + builder + .appendAccB3(Bytes.of(data.cBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB2(Bytes.of(data.cBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB1(Bytes.of(data.cBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB0(Bytes.of(data.cBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + + builder + .appendByteH3(UnsignedByte.of(data.hBytes.get(3, i))) + .appendByteH2(UnsignedByte.of(data.hBytes.get(2, i))) + .appendByteH1(UnsignedByte.of(data.hBytes.get(1, i))) + .appendByteH0(UnsignedByte.of(data.hBytes.get(0, i))); + builder + .appendAccB3(Bytes.of(data.hBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB2(Bytes.of(data.hBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB1(Bytes.of(data.hBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) + .appendAccB0(Bytes.of(data.hBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + builder + .appendExponentBit(data.exponentBit()) + .appendExponentBitAcc(data.expAcc.toUnsignedBigInteger()) + .appendExponentBitSource(data.exponentSource()) + .appendSquareAndMultiply(data.snm) + .appendBitNum(data.getBitNum()); + builder.setStamp(stamp); } } From e7055463b29c1bfb11a3845f023d07d0a3dad0b9 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Thu, 13 Apr 2023 17:20:39 +1000 Subject: [PATCH 15/31] got to the end of setArraysForZeroResultCase() Signed-off-by: Sally MacFarlane --- .../consensys/linea/zktracer/module/Util.java | 41 ++++++++++ .../zktracer/module/alu/mul/MulData.java | 80 ++++++++++++------- 2 files changed, 93 insertions(+), 28 deletions(-) create mode 100644 src/main/java/net/consensys/linea/zktracer/module/Util.java diff --git a/src/main/java/net/consensys/linea/zktracer/module/Util.java b/src/main/java/net/consensys/linea/zktracer/module/Util.java new file mode 100644 index 0000000000..6e9784d03a --- /dev/null +++ b/src/main/java/net/consensys/linea/zktracer/module/Util.java @@ -0,0 +1,41 @@ +package net.consensys.linea.zktracer.module; + +import java.math.BigInteger; + +import net.consensys.linea.zktracer.bytes.UnsignedByte; +import org.apache.tuweni.units.bigints.UInt64; + +public class Util { + + public static byte boolToByte(boolean b) { + if (b) { + return 1; + } + return 0; + } + + public static Boolean[] byteBits(final UnsignedByte b) { + final Boolean[] bits = new Boolean[8]; + for (int i = 0; i < 8; i++) { + bits[7 - i] = b.shiftRight(i).mod(2).toInteger() == 1; + } + return bits; + } + + public static int getOverflow(final BigInteger arg, final int maxVal, final String err) { + BigInteger shiftRight = arg.shiftRight(128); + if (shiftRight.compareTo(UInt64.MAX_VALUE.toBigInteger()) > 0) { + throw new RuntimeException("getOverflow expects a small high part"); + } + int overflow = shiftRight.intValue(); + if (overflow > maxVal) { + throw new RuntimeException(err); + } + return overflow; + } + + // GetBit returns true iff the k'th bit of x is 1 + public static boolean getBit(int x, int k) { + return (x >> k) % 2 == 1; + } +} diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index f3fcb700f2..76ee0daa41 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -1,15 +1,20 @@ package net.consensys.linea.zktracer.module.alu.mul; +import static net.consensys.linea.zktracer.module.Util.boolToByte; +import static net.consensys.linea.zktracer.module.Util.byteBits; +import static net.consensys.linea.zktracer.module.Util.getBit; +import static net.consensys.linea.zktracer.module.Util.getOverflow; + import java.lang.reflect.Array; import java.math.BigInteger; import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.bytes.Bytes16; import net.consensys.linea.zktracer.bytes.BytesBaseTheta; +import net.consensys.linea.zktracer.bytes.UnsignedByte; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; -import org.apache.tuweni.units.bigints.UInt64; @SuppressWarnings("UnusedVariable") public class MulData { @@ -35,7 +40,7 @@ public class MulData { BytesBaseTheta hBytes; boolean snm = false; int index; - boolean[] bits; + Boolean[] bits; String exponentBits; Res res; @@ -115,23 +120,59 @@ private void setArraysForZeroResultCase() { hBytes.set(2, i, callFunc(i, 7 - nuRem)); hBytes.set(3, i, callFunc(i, 7 - nuQuo)); } - // TODO the rest + + bits = byteBits(UnsignedByte.of(pivotByte)); + + int lowerBoundOnTwoAdicity = 8 * (int) (hBytes.get(3, 7)) + (int) (hBytes.get(2, 7)); + + if (nu >= 64) { + lowerBoundOnTwoAdicity += 64; + } + + // our lower bound should coincide with the 2-adicity + if (lowerBoundOnTwoAdicity != nu) { + String s = + String.format( + "2-adicity nu = %v != %v = lower bound on 2-adicity", nu, lowerBoundOnTwoAdicity); + throw new RuntimeException(s); + } + if (lowerBoundOnTwoAdicity == 0) { + throw new RuntimeException("lower bound on 2 adicity == 0 in the zero result case"); + } + + UInt256 twoFiftySix = UInt256.valueOf(256); + if (arg2.compareTo(twoFiftySix) >= 0) { + // arg2 = exponent >= 256 + hBytes.set(1, 6, (byte) ((lowerBoundOnTwoAdicity - 1) / 256)); + hBytes.set(1, 7, (byte) ((lowerBoundOnTwoAdicity - 1) % 256)); + } else { + // exponent < 256 + int exponent = arg2.toUnsignedBigInteger().intValue(); + int target = exponent * lowerBoundOnTwoAdicity - 256; + + if (target < 0) { + throw new RuntimeException("lower bound on 2-adicity is wrong"); + } + + if (target > 255 * (8 * 7 + 7 + 64)) { + throw new RuntimeException("something went awfully wrong ..."); + } + + final BytesBaseTheta thing = + new BytesBaseTheta(Bytes32.wrap(BigInteger.valueOf(target).toByteArray())); + hBytes.set(1, thing.getChunk(0)); + } + + return; } - private byte callFunc(final int x, final int k) { + public static byte callFunc(final int x, final int k) { if (x < k) { return 0; } return (byte) (x - k); } - private byte boolToByte(boolean b) { - if (b) { - return 1; - } - return 0; - } - public boolean exponentBit() { return '1' == exponentBits.charAt(index); } @@ -346,23 +387,6 @@ private void hiToLoExponentBitAccumulatorReset() { } } - public static int getOverflow(final BigInteger arg, final int maxVal, final String err) { - BigInteger shiftRight = arg.shiftRight(128); - if (shiftRight.compareTo(UInt64.MAX_VALUE.toBigInteger()) > 0) { - throw new RuntimeException("getOverflow expects a small high part"); - } - int overflow = shiftRight.intValue(); - if (overflow > maxVal) { - throw new RuntimeException(err); - } - return overflow; - } - - // GetBit returns true iff the k'th bit of x is 1 - private boolean getBit(int x, int k) { - return (x >> k) % 2 == 1; - } - public int maxCt() { return isOneLineInstruction() ? 1 : MMEDIUM; } From 015173da8d381d6a71269acde1dd92d124d02383 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Fri, 14 Apr 2023 08:12:29 +1000 Subject: [PATCH 16/31] java 17 switch statement Signed-off-by: Sally MacFarlane --- .../zktracer/module/alu/mul/MulData.java | 28 +++++++------------ .../zktracer/module/alu/mul/MulTracer.java | 11 +++++--- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index 76ee0daa41..5cdb690f82 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -58,11 +58,6 @@ public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { arg2Hi = Bytes16.wrap(arg2.slice(0, 16)); arg2Lo = Bytes16.wrap(arg2.slice(16)); - // TODO what should these be initialized to (or is this not needed) - this.cBytes = null; - this.hBytes = null; - this.expAcc = UInt256.MIN_VALUE; - this.res = Res.create(opCode, arg1, arg2); // TODO can we get this from the EVM final BigInteger arg1BigInt = UInt256.fromBytes(arg1).toUnsignedBigInteger(); @@ -74,20 +69,17 @@ public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { final Regime regime = getRegime(); System.out.println(regime); switch (regime) { - case TRIVIAL_MUL: - break; - case NON_TRIVIAL_MUL: - cBytes = new BytesBaseTheta(res); - break; - case EXPONENT_ZERO_RESULT: - setArraysForZeroResultCase(); - break; - case EXPONENT_NON_ZERO_RESULT: + case TRIVIAL_MUL -> {} + case NON_TRIVIAL_MUL -> + cBytes = new BytesBaseTheta(res); + case EXPONENT_ZERO_RESULT -> + setArraysForZeroResultCase(); + case EXPONENT_NON_ZERO_RESULT -> { this.exponentBits = arg2.toBigInteger().toString(); snm = false; - break; - case IOTA: - throw new RuntimeException("alu/mul regime was never set"); + } + case IOTA -> + throw new RuntimeException("alu/mul regime was never set"); } } @@ -133,7 +125,7 @@ private void setArraysForZeroResultCase() { if (lowerBoundOnTwoAdicity != nu) { String s = String.format( - "2-adicity nu = %v != %v = lower bound on 2-adicity", nu, lowerBoundOnTwoAdicity); + "2-adicity nu = %d != %d = lower bound on 2-adicity", nu, lowerBoundOnTwoAdicity); throw new RuntimeException(s); } if (lowerBoundOnTwoAdicity == 0) { diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java index 14d25631a6..9651655359 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java @@ -53,13 +53,14 @@ public Object trace(MessageFrame frame) { stamp++; switch (data.getRegime()) { - case EXPONENT_ZERO_RESULT: + case EXPONENT_ZERO_RESULT -> { for (int ct = 0; ct < maxCt; ct++) { trace(builder, data, ct); } return builder.build(); + } - case EXPONENT_NON_ZERO_RESULT: + case EXPONENT_NON_ZERO_RESULT -> { if (data.carryOn()) { data.update(); for (int ct = 0; ct < maxCt; ct++) { @@ -67,15 +68,17 @@ public Object trace(MessageFrame frame) { } } return builder.build(); + } - case TRIVIAL_MUL, NON_TRIVIAL_MUL: + case TRIVIAL_MUL, NON_TRIVIAL_MUL -> { data.setHsAndBits(arg1.toBigInteger(), arg2.toBigInteger()); for (int ct = 0; ct < maxCt; ct++) { trace(builder, data, ct); } return builder.build(); + } - default: + default -> throw new RuntimeException("regime not supported"); } } From d88101af80d66931db8612fda4b0b9165d4e1a9e Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Fri, 14 Apr 2023 12:44:42 +1000 Subject: [PATCH 17/31] formatting Signed-off-by: Sally MacFarlane --- .../consensys/linea/zktracer/module/alu/mul/MulData.java | 9 +++------ .../linea/zktracer/module/alu/mul/MulTracer.java | 3 +-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index 5cdb690f82..21d697e8f8 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -70,16 +70,13 @@ public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { System.out.println(regime); switch (regime) { case TRIVIAL_MUL -> {} - case NON_TRIVIAL_MUL -> - cBytes = new BytesBaseTheta(res); - case EXPONENT_ZERO_RESULT -> - setArraysForZeroResultCase(); + case NON_TRIVIAL_MUL -> cBytes = new BytesBaseTheta(res); + case EXPONENT_ZERO_RESULT -> setArraysForZeroResultCase(); case EXPONENT_NON_ZERO_RESULT -> { this.exponentBits = arg2.toBigInteger().toString(); snm = false; } - case IOTA -> - throw new RuntimeException("alu/mul regime was never set"); + case IOTA -> throw new RuntimeException("alu/mul regime was never set"); } } diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java index 9651655359..138d17feb9 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java @@ -78,8 +78,7 @@ public Object trace(MessageFrame frame) { return builder.build(); } - default -> - throw new RuntimeException("regime not supported"); + default -> throw new RuntimeException("regime not supported"); } } From 17340bb729264f650d0ce537ef96c5d1e87d8b0c Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Fri, 14 Apr 2023 12:44:59 +1000 Subject: [PATCH 18/31] header Signed-off-by: Sally MacFarlane --- .../linea/zktracer/bytes/BytesBaseTheta.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java b/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java index 32ee90a905..01bb053254 100644 --- a/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java +++ b/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java @@ -1,3 +1,18 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + package net.consensys.linea.zktracer.bytes; import java.math.BigInteger; From bb7467a2b05c0fb3c3837c6e80d680e719ba0595 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Fri, 14 Apr 2023 14:28:57 +1000 Subject: [PATCH 19/31] initialize values Signed-off-by: Sally MacFarlane --- .../linea/zktracer/bytes/BytesBaseTheta.java | 115 ----------------- .../linea/zktracer/bytestheta/BaseBytes.java | 6 +- .../linea/zktracer/bytestheta/BaseTheta.java | 2 +- .../consensys/linea/zktracer/module/Util.java | 23 +++- .../zktracer/module/alu/mul/MulData.java | 119 +++++++++--------- .../zktracer/module/alu/mul/MulTracer.java | 40 +++--- .../linea/zktracer/module/alu/mul/Muler.java | 29 ----- .../linea/zktracer/module/alu/mul/Res.java | 50 -------- 8 files changed, 104 insertions(+), 280 deletions(-) delete mode 100644 src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java delete mode 100644 src/main/java/net/consensys/linea/zktracer/module/alu/mul/Muler.java delete mode 100644 src/main/java/net/consensys/linea/zktracer/module/alu/mul/Res.java diff --git a/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java b/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java deleted file mode 100644 index 01bb053254..0000000000 --- a/src/main/java/net/consensys/linea/zktracer/bytes/BytesBaseTheta.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright ConsenSys AG. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - * - * SPDX-License-Identifier: Apache-2.0 - */ - -package net.consensys.linea.zktracer.bytes; - -import java.math.BigInteger; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.Arrays; - -import net.consensys.linea.zktracer.module.alu.mul.Res; -import org.apache.tuweni.bytes.Bytes32; - -public class BytesBaseTheta { - - private byte[][] bytes; - - public BytesBaseTheta(final Bytes32 arg) { - bytes = new byte[4][8]; - byte[] argBytes = arg.toArray(); - - for (int k = 0; k < 4; k++) { - System.arraycopy(argBytes, 8 * k, bytes[3 - k], 0, 8); - } - } - - public BytesBaseTheta(final Res res) { - bytes = new byte[4][8]; - byte[] argBytesHi = res.getResHi().toArray(); - byte[] argBytesLo = res.getResLo().toArray(); - - for (int k = 0; k < 2; k++) { - System.arraycopy(argBytesHi, 8 * k, bytes[3 - k], 0, 8); - } - for (int k = 2; k < 4; k++) { - System.arraycopy(argBytesLo, 8 * (k - 2), bytes[3 - k], 0, 8); - } - } - - public void set(final BigInteger bigInteger) { - // TODO how to get from BigInteger to bytes - bigInteger.toByteArray(); - } - - public void set(final int i, final BigInteger bigInteger) { - // TODO handle underflow - byte[] bigIntByteArray = bigInteger.toByteArray(); - System.arraycopy(bigIntByteArray, 0, bytes[i], 0, 8); - } - - public void set(final int i, final byte[] chunk) { - // TODO handle underflow - System.arraycopy(chunk, 0, bytes[i], 0, 8); - } - - // TODO can Res become Pair as below - public Pair getHiLo() { - byte[] hiBytes = new byte[16]; - byte[] loBytes = new byte[16]; - - System.arraycopy(bytes[3], 0, hiBytes, 0, 8); - System.arraycopy(bytes[2], 0, hiBytes, 8, 8); - - System.arraycopy(bytes[1], 0, loBytes, 0, 8); - System.arraycopy(bytes[0], 0, loBytes, 8, 8); - - return new Pair<>(hiBytes, loBytes); - } - - public byte[] getChunk(final int i) { - return bytes[i]; - } - - public byte get(final int i, final int j) { - return bytes[i][j]; - } - - public byte[] getRange(final int i, final int start, final int end) { - return Arrays.copyOfRange(bytes[i], start, end); - } - - public void set(int i, int j, byte b) { - bytes[i][j] = b; - } -} - -@SuppressWarnings("UnusedVariable") -record Pair(A first, B second) {} - -class UInt256 { - private byte[] bytes; - - public UInt256(byte[] bytes) { - this.bytes = bytes; - } - - public byte[] getBytes32() { - ByteBuffer buf = ByteBuffer.allocate(32); - buf.order(ByteOrder.BIG_ENDIAN); - buf.put(bytes); - return buf.array(); - } -} diff --git a/src/main/java/net/consensys/linea/zktracer/bytestheta/BaseBytes.java b/src/main/java/net/consensys/linea/zktracer/bytestheta/BaseBytes.java index 6fbeacefd3..4a5412a70e 100644 --- a/src/main/java/net/consensys/linea/zktracer/bytestheta/BaseBytes.java +++ b/src/main/java/net/consensys/linea/zktracer/bytestheta/BaseBytes.java @@ -24,7 +24,7 @@ public class BaseBytes { private final int LOW_HIGH_SIZE = 16; protected MutableBytes32 bytes32; - static BaseBytes fromBytes32(Bytes32 arg) { + public static BaseBytes fromBytes32(Bytes32 arg) { return new BaseBytes(arg); } @@ -47,4 +47,8 @@ public byte getByte(int index) { public Bytes32 getBytes32() { return bytes32; } + + public boolean isZero() { + return bytes32.isZero(); + } } diff --git a/src/main/java/net/consensys/linea/zktracer/bytestheta/BaseTheta.java b/src/main/java/net/consensys/linea/zktracer/bytestheta/BaseTheta.java index 2ff1efaee1..4a59380afa 100644 --- a/src/main/java/net/consensys/linea/zktracer/bytestheta/BaseTheta.java +++ b/src/main/java/net/consensys/linea/zktracer/bytestheta/BaseTheta.java @@ -29,7 +29,7 @@ private BaseTheta(final Bytes32 arg) { } } - static BaseTheta fromBytes32(Bytes32 arg) { + public static BaseTheta fromBytes32(Bytes32 arg) { return new BaseTheta(arg); } diff --git a/src/main/java/net/consensys/linea/zktracer/module/Util.java b/src/main/java/net/consensys/linea/zktracer/module/Util.java index 6e9784d03a..45ad63477c 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/Util.java +++ b/src/main/java/net/consensys/linea/zktracer/module/Util.java @@ -1,8 +1,21 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ package net.consensys.linea.zktracer.module; -import java.math.BigInteger; - import net.consensys.linea.zktracer.bytes.UnsignedByte; +import org.apache.tuweni.units.bigints.UInt256; import org.apache.tuweni.units.bigints.UInt64; public class Util { @@ -22,9 +35,9 @@ public static Boolean[] byteBits(final UnsignedByte b) { return bits; } - public static int getOverflow(final BigInteger arg, final int maxVal, final String err) { - BigInteger shiftRight = arg.shiftRight(128); - if (shiftRight.compareTo(UInt64.MAX_VALUE.toBigInteger()) > 0) { + public static int getOverflow(final UInt256 arg, final int maxVal, final String err) { + UInt256 shiftRight = arg.shiftRight(128); + if (shiftRight.compareTo(UInt64.MAX_VALUE.toBytes()) > 0) { throw new RuntimeException("getOverflow expects a small high part"); } int overflow = shiftRight.intValue(); diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index 21d697e8f8..62d23eaf6d 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -10,8 +10,9 @@ import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.bytes.Bytes16; -import net.consensys.linea.zktracer.bytes.BytesBaseTheta; import net.consensys.linea.zktracer.bytes.UnsignedByte; +import net.consensys.linea.zktracer.bytestheta.BaseBytes; +import net.consensys.linea.zktracer.bytestheta.BaseTheta; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; @@ -31,34 +32,36 @@ public class MulData { final boolean tinyBase; final boolean tinyExponent; - BigInteger resAcc; // accumulator which converges in a series of "square and multiply"'s - UInt256 expAcc; // accumulator for doubles and adds of the exponent, resets at some point + UInt256 resAcc = + UInt256.ZERO; // accumulator which converges in a series of "square and multiply"'s + UInt256 expAcc = + UInt256.ZERO; // accumulator for doubles and adds of the exponent, resets at some point - final BytesBaseTheta aBytes; - final BytesBaseTheta bBytes; - BytesBaseTheta cBytes; - BytesBaseTheta hBytes; + final BaseTheta aBytes; + final BaseTheta bBytes; + BaseTheta cBytes = BaseTheta.fromBytes32(Bytes32.ZERO); + BaseTheta hBytes = BaseTheta.fromBytes32(Bytes32.ZERO); boolean snm = false; int index; - Boolean[] bits; - String exponentBits; + Boolean[] bits = new Boolean[8]; + String exponentBits = "0"; - Res res; + BaseBytes res; public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { this.opCode = opCode; this.arg1 = arg1; this.arg2 = arg2; - this.aBytes = new BytesBaseTheta(arg1); - this.bBytes = new BytesBaseTheta(arg2); + this.aBytes = BaseTheta.fromBytes32(arg1); + this.bBytes = BaseTheta.fromBytes32(arg2); arg1Hi = Bytes16.wrap(arg1.slice(0, 16)); arg1Lo = Bytes16.wrap(arg1.slice(16)); arg2Hi = Bytes16.wrap(arg2.slice(0, 16)); arg2Lo = Bytes16.wrap(arg2.slice(16)); - this.res = Res.create(opCode, arg1, arg2); // TODO can we get this from the EVM + this.res = getRes(opCode, arg1, arg2); // TODO can we get this from the EVM final BigInteger arg1BigInt = UInt256.fromBytes(arg1).toUnsignedBigInteger(); final BigInteger arg2BigInt = UInt256.fromBytes(arg2).toUnsignedBigInteger(); @@ -67,10 +70,9 @@ public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { this.tinyExponent = isTiny(arg2BigInt); final Regime regime = getRegime(); - System.out.println(regime); switch (regime) { case TRIVIAL_MUL -> {} - case NON_TRIVIAL_MUL -> cBytes = new BytesBaseTheta(res); + case NON_TRIVIAL_MUL -> cBytes = BaseTheta.fromBytes32(res.getBytes32()); case EXPONENT_ZERO_RESULT -> setArraysForZeroResultCase(); case EXPONENT_NON_ZERO_RESULT -> { this.exponentBits = arg2.toBigInteger().toString(); @@ -80,6 +82,15 @@ public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { } } + private static BaseBytes getRes(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { + + return switch (opCode) { + case MUL -> BaseBytes.fromBytes32(UInt256.fromBytes(arg1).multiply(UInt256.fromBytes(arg2))); + case EXP -> BaseBytes.fromBytes32(UInt256.fromBytes(arg1).pow(UInt256.fromBytes(arg2))); + default -> BaseBytes.fromBytes32(UInt256.ZERO); + }; + } + private void setArraysForZeroResultCase() { int nu = twoAdicity(arg1); @@ -87,20 +98,20 @@ private void setArraysForZeroResultCase() { return; } - byte[] ones = Bytes.repeat((byte) 1, 8).toArray(); - byte[] bytes; + Bytes ones = Bytes.repeat((byte) 1, 8); + Bytes bytes; if (128 > nu && nu >= 64) { - bytes = aBytes.getChunk(1); + bytes = aBytes.get(1); } else { for (int i = 0; i < 8; i++) { - cBytes.set(0, ones); + cBytes.setBytes(0, ones); } - bytes = aBytes.getChunk(0); + bytes = aBytes.get(0); } int nuQuo = (nu / 8) % 8; int nuRem = nu % 8; - byte pivotByte = bytes[7 - nuQuo]; + byte pivotByte = bytes.get(7 - nuQuo); for (int i = 0; i < 8; i++) { cBytes.set(1, i, pivotByte); @@ -129,7 +140,7 @@ private void setArraysForZeroResultCase() { throw new RuntimeException("lower bound on 2 adicity == 0 in the zero result case"); } - UInt256 twoFiftySix = UInt256.valueOf(256); + final UInt256 twoFiftySix = UInt256.valueOf(256); if (arg2.compareTo(twoFiftySix) >= 0) { // arg2 = exponent >= 256 hBytes.set(1, 6, (byte) ((lowerBoundOnTwoAdicity - 1) / 256)); @@ -147,9 +158,8 @@ private void setArraysForZeroResultCase() { throw new RuntimeException("something went awfully wrong ..."); } - final BytesBaseTheta thing = - new BytesBaseTheta(Bytes32.wrap(BigInteger.valueOf(target).toByteArray())); - hBytes.set(1, thing.getChunk(0)); + final BaseTheta thing = BaseTheta.fromBytes32(UInt256.valueOf(target)); + hBytes.setBytes(1, thing.get(0)); } return; @@ -233,8 +243,8 @@ public boolean carryOn() { // first round is special if (index == 0 && !snm) { snm = true; - resAcc = BigInteger.ONE; // TODO assuming this is what SetOne() does - cBytes.set(arg1.toBigInteger()); + resAcc = UInt256.valueOf(1); // TODO assuming this is what SetOne() does + cBytes = BaseTheta.fromBytes32(arg1); return true; } @@ -278,40 +288,36 @@ public void update() { resAcc = resAcc.multiply(resAcc); } else { // multiplying by base - setHsAndBits(arg1BigInt, resAcc); + setHsAndBits(UInt256.valueOf(arg1BigInt), resAcc); expAcc = expAcc.add(UInt256.ONE); - resAcc = arg1BigInt.multiply(resAcc); + resAcc = UInt256.valueOf(arg1BigInt).multiply(resAcc); } - cBytes.set(resAcc); + cBytes = BaseTheta.fromBytes32(resAcc); } - public void setHsAndBits(BigInteger a, BigInteger b) { + public void setHsAndBits(UInt256 a, UInt256 b) { - BytesBaseTheta aBaseTheta, bBaseTheta, sumBaseTheta; - aBaseTheta = new BytesBaseTheta(Bytes32.ZERO); - bBaseTheta = new BytesBaseTheta(Bytes32.ZERO); - sumBaseTheta = new BytesBaseTheta(Bytes32.ZERO); + BaseTheta aBaseTheta = BaseTheta.fromBytes32(a); + BaseTheta bBaseTheta = BaseTheta.fromBytes32(b); + BaseTheta sumBaseTheta; - aBaseTheta.set(a); - bBaseTheta.set(b); - - BigInteger[] aBaseThetaInts = (BigInteger[]) Array.newInstance(UInt256.class, 4); - BigInteger[] bBaseThetaInts = (BigInteger[]) Array.newInstance(UInt256.class, 4); + UInt256[] aBaseThetaInts = (UInt256[]) Array.newInstance(UInt256.class, 4); + UInt256[] bBaseThetaInts = (UInt256[]) Array.newInstance(UInt256.class, 4); for (int i = 0; i < 4; i++) { - aBaseThetaInts[i] = Bytes.of(aBaseTheta.getChunk(i)).toBigInteger(); - bBaseThetaInts[i] = Bytes.of(bBaseTheta.getChunk(i)).toBigInteger(); + aBaseThetaInts[i] = UInt256.fromBytes(aBaseTheta.get(i)); + bBaseThetaInts[i] = UInt256.fromBytes(bBaseTheta.get(i)); } - BigInteger sum, prod; + UInt256 sum, prod; prod = aBaseThetaInts[1].multiply(bBaseThetaInts[0]); sum = prod; // sum := a1 * b0 prod = aBaseThetaInts[0].multiply(bBaseThetaInts[1]); sum = sum.add(prod); // sum += a0 * b1 - sumBaseTheta.set(sum); - hBytes.set(0, sumBaseTheta.getChunk(0)); - hBytes.set(1, sumBaseTheta.getChunk(1)); + sumBaseTheta = BaseTheta.fromBytes32(sum); + hBytes.setBytes(0, sumBaseTheta.get(0)); + hBytes.setBytes(1, sumBaseTheta.get(1)); int alpha = getOverflow(sum, 1, "alpha OOB"); sum = aBaseThetaInts[3].multiply(bBaseThetaInts[0]); // sum := a3 * b0 @@ -322,26 +328,26 @@ public void setHsAndBits(BigInteger a, BigInteger b) { prod = aBaseThetaInts[0].multiply(bBaseThetaInts[3]); sum = sum.add(prod); // sum += a0 * b3 - sumBaseTheta.set(sum); - hBytes.set(2, sumBaseTheta.getChunk(0)); - hBytes.set(3, sumBaseTheta.getChunk(1)); + sumBaseTheta = BaseTheta.fromBytes32(sum); + hBytes.setBytes(2, sumBaseTheta.get(0)); + hBytes.setBytes(3, sumBaseTheta.get(1)); int beta = getOverflow(sum, 3, "beta OOB"); sum = aBaseThetaInts[0].multiply(bBaseThetaInts[0]); // sum := a0 * b0 - sum = sum.add(shiftLeft64(hBytes.getChunk(0))); // sum += (h0 << 64) + sum = sum.add(UInt256.fromBytes(hBytes.get(0).shiftLeft(64))); // sum += (h0 << 64) int eta = getOverflow(sum, 1, "eta OOB"); - sum = BigInteger.valueOf(eta); // sum := eta - sum = sum.add(Bytes16.wrap(hBytes.getChunk(1)).toUnsignedBigInteger()); // sum += h1 - sum = sum.add(BigInteger.valueOf(alpha).shiftLeft(64)); // sum += (alpha << 64) + sum = UInt256.valueOf(eta); // sum := eta + sum = sum.add(UInt256.fromBytes(hBytes.get(1))); // sum += h1 + sum = sum.add(UInt256.valueOf(alpha).shiftLeft(64)); // sum += (alpha << 64) prod = aBaseThetaInts[2].multiply(bBaseThetaInts[0]); sum = sum.add(prod); // sum += a2 * b0 prod = aBaseThetaInts[1].multiply(bBaseThetaInts[1]); sum = sum.add(prod); // sum += a1 * b1 prod = aBaseThetaInts[0].multiply(bBaseThetaInts[2]); sum = sum.add(prod); // sum += a0 * b2 - sum = sum.add(shiftLeft64(hBytes.getChunk(2))); // sum += (h2 << 64) + sum = sum.add(UInt256.fromBytes(hBytes.get(2).shiftLeft(64))); // sum += (h2 << 64) int mu = getOverflow(sum, 3, "mu OOB"); @@ -355,11 +361,6 @@ public void setHsAndBits(BigInteger a, BigInteger b) { return; } - private BigInteger shiftLeft64(byte[] b16) { - final Bytes16 copy = Bytes16.wrap(b16).copy(); - return copy.shiftLeft(64).toUnsignedBigInteger(); - } - // hiToLoExponentBitAccumulatorReset resets the exponent bit accumulator // under the following conditions: // - we are dealing with the high part of the exponent bits, i.e. md.exponentBit() = 0 diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java index 138d17feb9..7e1a051852 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java @@ -21,8 +21,8 @@ import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.bytes.UnsignedByte; import net.consensys.linea.zktracer.module.ModuleTracer; -import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; +import org.apache.tuweni.units.bigints.UInt256; public class MulTracer implements ModuleTracer { @@ -71,7 +71,7 @@ public Object trace(MessageFrame frame) { } case TRIVIAL_MUL, NON_TRIVIAL_MUL -> { - data.setHsAndBits(arg1.toBigInteger(), arg2.toBigInteger()); + data.setHsAndBits(UInt256.fromBytes(arg1), UInt256.fromBytes(arg2)); for (int ct = 0; ct < maxCt; ct++) { trace(builder, data, ct); } @@ -100,8 +100,8 @@ private void trace(final MulTrace.Trace.Builder builder, final MulData data, fin .appendArg2Lo(data.arg2Lo.toUnsignedBigInteger()); builder - .appendResHi(data.res.getResHi().toUnsignedBigInteger()) - .appendResLo(data.res.getResLo().toUnsignedBigInteger()); + .appendResHi(data.res.getHigh().toUnsignedBigInteger()) + .appendResLo(data.res.getLow().toUnsignedBigInteger()); builder.appendBits(data.bits[i]); @@ -111,10 +111,10 @@ private void trace(final MulTrace.Trace.Builder builder, final MulData data, fin .appendByteA1(UnsignedByte.of(data.aBytes.get(1, i))) .appendByteA0(UnsignedByte.of(data.aBytes.get(0, i))); builder - .appendAccA3(Bytes.of(data.aBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccA2(Bytes.of(data.aBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccA1(Bytes.of(data.aBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccA0(Bytes.of(data.aBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + .appendAccA3(data.aBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) + .appendAccA2(data.aBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) + .appendAccA1(data.aBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) + .appendAccA0(data.aBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); builder .appendByteB3(UnsignedByte.of(data.bBytes.get(3, i))) @@ -122,20 +122,20 @@ private void trace(final MulTrace.Trace.Builder builder, final MulData data, fin .appendByteB1(UnsignedByte.of(data.bBytes.get(1, i))) .appendByteB0(UnsignedByte.of(data.bBytes.get(0, i))); builder - .appendAccB3(Bytes.of(data.bBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB2(Bytes.of(data.bBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB1(Bytes.of(data.bBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB0(Bytes.of(data.bBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + .appendAccB3(data.bBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) + .appendAccB2(data.bBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) + .appendAccB1(data.bBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) + .appendAccB0(data.bBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); builder .appendByteC3(UnsignedByte.of(data.cBytes.get(3, i))) .appendByteC2(UnsignedByte.of(data.cBytes.get(2, i))) .appendByteC1(UnsignedByte.of(data.cBytes.get(1, i))) .appendByteC0(UnsignedByte.of(data.cBytes.get(0, i))); builder - .appendAccB3(Bytes.of(data.cBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB2(Bytes.of(data.cBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB1(Bytes.of(data.cBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB0(Bytes.of(data.cBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + .appendAccB3(data.cBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) + .appendAccB2(data.cBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) + .appendAccB1(data.cBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) + .appendAccB0(data.cBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); builder .appendByteH3(UnsignedByte.of(data.hBytes.get(3, i))) @@ -143,10 +143,10 @@ private void trace(final MulTrace.Trace.Builder builder, final MulData data, fin .appendByteH1(UnsignedByte.of(data.hBytes.get(1, i))) .appendByteH0(UnsignedByte.of(data.hBytes.get(0, i))); builder - .appendAccB3(Bytes.of(data.hBytes.getRange(3, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB2(Bytes.of(data.hBytes.getRange(2, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB1(Bytes.of(data.hBytes.getRange(1, 0, i + 1)).toUnsignedBigInteger()) - .appendAccB0(Bytes.of(data.hBytes.getRange(0, 0, i + 1)).toUnsignedBigInteger()); + .appendAccB3(data.hBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) + .appendAccB2(data.hBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) + .appendAccB1(data.hBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) + .appendAccB0(data.hBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); builder .appendExponentBit(data.exponentBit()) .appendExponentBitAcc(data.expAcc.toUnsignedBigInteger()) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/Muler.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/Muler.java deleted file mode 100644 index 95288c2107..0000000000 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/Muler.java +++ /dev/null @@ -1,29 +0,0 @@ -package net.consensys.linea.zktracer.module.alu.mul; -/* - * Copyright ConsenSys AG. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - * - * SPDX-License-Identifier: Apache-2.0 - */ -import net.consensys.linea.zktracer.OpCode; -import org.apache.tuweni.bytes.Bytes32; -import org.apache.tuweni.units.bigints.UInt256; - -public class Muler { - - public static UInt256 operate(final OpCode opCode, final Bytes32 arg1, final Bytes32 arg2) { - return switch (opCode) { - case MUL -> UInt256.fromBytes(arg1).multiply(UInt256.fromBytes(arg2)); - case EXP -> UInt256.fromBytes(arg1).pow(UInt256.fromBytes(arg2)); - default -> UInt256.ZERO; - }; - } -} diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/Res.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/Res.java deleted file mode 100644 index cb964dfe09..0000000000 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/Res.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright ConsenSys AG. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - * - * SPDX-License-Identifier: Apache-2.0 - */ -package net.consensys.linea.zktracer.module.alu.mul; - -import net.consensys.linea.zktracer.OpCode; -import net.consensys.linea.zktracer.bytes.Bytes16; -import org.apache.tuweni.bytes.Bytes32; - -public class Res { - final Bytes16 resHi; - final Bytes16 resLo; - final boolean isZero; - - private Res(Bytes16 resHi, Bytes16 resLo, boolean isZero) { - this.resHi = resHi; - this.resLo = resLo; - this.isZero = isZero; - } - - public Bytes16 getResHi() { - return resHi; - } - - public Bytes16 getResLo() { - return resLo; - } - - public static Res create(final OpCode opCode, final Bytes32 arg1, final Bytes32 arg2) { - final Bytes32 result = Muler.operate(opCode, arg1, arg2); - - return new Res( - Bytes16.wrap(result.slice(0, 16)), Bytes16.wrap(result.slice(16)), result.isZero()); - } - - public boolean isZero() { - return isZero; - } -} From b49827062b9b56546b30e95df149b849004d1c50 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Mon, 17 Apr 2023 12:19:38 +1000 Subject: [PATCH 20/31] formatting Signed-off-by: Sally MacFarlane --- src/main/java/net/consensys/linea/zktracer/ZkTracer.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/main/java/net/consensys/linea/zktracer/ZkTracer.java b/src/main/java/net/consensys/linea/zktracer/ZkTracer.java index 5e7c663c39..4f5f1df711 100644 --- a/src/main/java/net/consensys/linea/zktracer/ZkTracer.java +++ b/src/main/java/net/consensys/linea/zktracer/ZkTracer.java @@ -23,8 +23,8 @@ import net.consensys.linea.zktracer.module.ModuleTracer; import net.consensys.linea.zktracer.module.alu.add.AddTracer; -import net.consensys.linea.zktracer.module.alu.mul.MulTracer; import net.consensys.linea.zktracer.module.alu.mod.ModTracer; +import net.consensys.linea.zktracer.module.alu.mul.MulTracer; import net.consensys.linea.zktracer.module.shf.ShfTracer; import net.consensys.linea.zktracer.module.wcp.WcpTracer; @@ -43,7 +43,8 @@ public ZkTracer(final ZkTraceBuilder zkTraceBuilder, final List tr public ZkTracer(final ZkTraceBuilder zkTraceBuilder) { this( zkTraceBuilder, - List.of(new MulTracer(), new ShfTracer(), new WcpTracer(), new AddTracer(), new ModTracer())); + List.of( + new MulTracer(), new ShfTracer(), new WcpTracer(), new AddTracer(), new ModTracer())); } @Override From e2c231442604d84b25cef89981b9b0dfb49f994f Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 19 Apr 2023 15:32:34 +1000 Subject: [PATCH 21/31] converted to extend abstract test class Signed-off-by: Sally MacFarlane --- .../module/alu/mul/MulTracerTest.java | 191 +++++++----------- 1 file changed, 72 insertions(+), 119 deletions(-) diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java index bdc92029d7..5f0ebf7fee 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java @@ -14,124 +14,85 @@ */ package net.consensys.linea.zktracer.module.alu.mul; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import static org.mockito.Mockito.when; - -import org.hyperledger.besu.evm.frame.MessageFrame; -import org.hyperledger.besu.evm.operation.Operation; - +import java.util.ArrayList; import java.util.List; import java.util.Random; import java.util.stream.Stream; -import net.consensys.linea.CorsetValidator; import net.consensys.linea.zktracer.OpCode; -import net.consensys.linea.zktracer.ZkTraceBuilder; -import net.consensys.linea.zktracer.ZkTracer; +import net.consensys.linea.zktracer.module.AbstractModuleTracerTest; +import net.consensys.linea.zktracer.module.ModuleTracer; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Named; -import org.junit.jupiter.api.Test; +import org.apache.tuweni.units.bigints.UInt256; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; @ExtendWith(MockitoExtension.class) -class MulTracerTest { - private static final Logger LOG = LoggerFactory.getLogger(MulTracerTest.class); - +class MulTracerTest extends AbstractModuleTracerTest { private static final Random rand = new Random(); - private static final int TEST_REPETITIONS = 4; - private ZkTracer zkTracer; - private ZkTraceBuilder zkTraceBuilder; + private static final int TEST_MUL_REPETITIONS = 16; - @Mock MessageFrame mockFrame; - @Mock Operation mockOperation; - - @BeforeEach - void setUp() { - zkTraceBuilder = new ZkTraceBuilder(); - zkTracer = new ZkTracer(zkTraceBuilder, List.of(new MulTracer())); - - when(mockFrame.getCurrentOperation()).thenReturn(mockOperation); + @ParameterizedTest() + @MethodSource("provideRandomAluMulArguments") + void aluMulTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, arg1, arg2); } - @ParameterizedTest(name = "{0}") - @MethodSource("provideMulOperators") - void testFailingBlockchainBlock(final int opCodeValue) { - when(mockOperation.getOpcode()).thenReturn(opCodeValue); - - when(mockFrame.getStackItem(0)).thenReturn(Bytes32.rightPad(Bytes.fromHexString("0x08"))); - when(mockFrame.getStackItem(1)).thenReturn(Bytes32.fromHexString("0x0128")); - - zkTracer.tracePreExecution(mockFrame); - - assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + @ParameterizedTest() + @MethodSource("provideSimpleAluMulArguments") + void simpleTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, arg1, arg2); } - @ParameterizedTest(name = "{0}") - @MethodSource("provideRandomArguments") - void testRandomExp(final Bytes32[] payload) { - LOG.info("arg1: " + payload[0].toShortHexString() + ", arg2: " + payload[1].toShortHexString()); - when(mockOperation.getOpcode()).thenReturn((int) OpCode.EXP.value); - - when(mockFrame.getStackItem(0)).thenReturn(payload[0]); - when(mockFrame.getStackItem(1)).thenReturn(payload[1]); - - zkTracer.tracePreExecution(mockFrame); - - assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + @ParameterizedTest() + @MethodSource("provideTinyArguments") + void tinyArgsTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, arg1, arg2); } - @ParameterizedTest(name = "{0}") - @MethodSource("provideNonRandomTinyArguments") - void testNonRandomTinyMul(final Bytes32[] payload) { - LOG.info("arg1: " + payload[0].toShortHexString() + ", arg2: " + payload[1].toShortHexString()); - when(mockOperation.getOpcode()).thenReturn((int) OpCode.EXP.value); - - when(mockFrame.getStackItem(0)).thenReturn(payload[0]); - when(mockFrame.getStackItem(1)).thenReturn(payload[1]); - - zkTracer.tracePreExecution(mockFrame); - - assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + @ParameterizedTest() + @MethodSource("provideSpecificNonTinyArguments") + void nonTinyArgsTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, arg1, arg2); } - @ParameterizedTest(name = "{0}") - @MethodSource("provideNonRandomNonTinyArguments") - void testNonRandomNonTinyMul(final Bytes32[] payload) { - LOG.info("arg1: " + payload[0].toShortHexString() + ", arg2: " + payload[1].toShortHexString()); - when(mockOperation.getOpcode()).thenReturn((int) OpCode.EXP.value); - - when(mockFrame.getStackItem(0)).thenReturn(payload[0]); - when(mockFrame.getStackItem(1)).thenReturn(payload[1]); + @ParameterizedTest() + @MethodSource("provideRandomNonTinyArguments") + void randomNonTinyArgsTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, arg1, arg2); + } - zkTracer.tracePreExecution(mockFrame); + public Stream provideSimpleAluMulArguments() { + List arguments = new ArrayList<>(); - assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + Bytes32 bytes1 = Bytes32.rightPad(Bytes.fromHexString("0x80")); + Bytes32 bytes2 = Bytes32.leftPad(Bytes.fromHexString("0x01")); + arguments.add(Arguments.of(getRandomSupportedOpcode(), bytes1, bytes2)); + return arguments.stream(); } - @Test - void testSimpleMul() { - when(mockOperation.getOpcode()).thenReturn((int) OpCode.MUL.value); - - when(mockFrame.getStackItem(0)) - .thenReturn(Bytes32.fromHexStringLenient("0x54fda4f3c1452c8c58df4fb1e9d6de")); - when(mockFrame.getStackItem(1)).thenReturn(Bytes32.fromHexStringLenient("0xb5")); + public Stream provideRandomAluMulArguments() { + List arguments = new ArrayList<>(); - zkTracer.tracePreExecution(mockFrame); + for (int i = 0; i < TEST_MUL_REPETITIONS; i++) { + arguments.add(getRandomAluMulInstruction(rand.nextInt(32) + 1, rand.nextInt(32) + 1)); + } + return arguments.stream(); + } - assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + private Arguments getRandomAluMulInstruction(int sizeArg1MinusOne, int sizeArg2MinusOne) { + Bytes32 bytes1 = UInt256.valueOf(sizeArg1MinusOne); + Bytes32 bytes2 = UInt256.valueOf(sizeArg2MinusOne); + OpCode opCode = getRandomSupportedOpcode(); + return Arguments.of(opCode, bytes1, bytes2); } - public static Stream provideNonRandomNonTinyArguments() { + public Stream provideSpecificNonTinyArguments() { // these values are used in Go module test // 0x8a, 0x48, 0xaa, 0x20, 0xe2, 0x00, 0xce, 0x3f, 0xee, 0x16, 0xb5, 0xdc, 0xde, 0xc5, 0xc4, // 0xfa, @@ -147,51 +108,43 @@ public static Stream provideNonRandomNonTinyArguments() { Bytes32.fromHexString("0x8a48aa20e200ce3fee16b5dcdec5c4faff613bc914d47cd6ca69553f8eb2b377"); payload[1] = Bytes32.fromHexString("0x59b635fec894caa3ed6817b1e67b3cbaeb8757fd6c7b03119b795303b7cd72c1"); - return Stream.of( - Arguments.of(Named.of("arg1: " + payload[0] + ", arg2: " + payload[1], payload))); + return Stream.of(Arguments.of(getRandomSupportedOpcode(), payload[0], payload[1])); } - public static Stream provideNonRandomTinyArguments() { - final Arguments[] arguments = new Arguments[TEST_REPETITIONS]; + public Stream provideRandomNonTinyArguments() { + List arguments = new ArrayList<>(); - for (int i = 0; i < TEST_REPETITIONS; i++) { - Bytes32[] payload = new Bytes32[2]; - payload[0] = Bytes32.leftPad(Bytes.of(1 + i)); - payload[1] = Bytes32.leftPad(Bytes.of(i)); - arguments[i] = - Arguments.of(Named.of("arg1: " + payload[0] + ", arg2: " + payload[1], payload)); + for (int i = 0; i < TEST_MUL_REPETITIONS; i++) { + arguments.add(getRandomAluMulInstruction(rand.nextInt(32) + 1, rand.nextInt(32) + 1)); } - - return Stream.of(arguments); + return arguments.stream(); } - public static Stream provideRandomArguments() { - final Arguments[] arguments = new Arguments[TEST_REPETITIONS]; - - for (int i = 0; i < TEST_REPETITIONS; i++) { + public Stream provideTinyArguments() { + List arguments = new ArrayList<>(); - final byte[] randomBytes1 = new byte[32]; - rand.nextBytes(randomBytes1); - final byte[] randomBytes2 = new byte[32]; - rand.nextBytes(randomBytes2); + for (int i = 0; i < 4; i++) { + arguments.add(getRandomAluMulInstruction(i, i + 1)); + } - Bytes32[] payload = new Bytes32[2]; - payload[0] = Bytes32.wrap(randomBytes1); - payload[1] = Bytes32.wrap(randomBytes2); + return arguments.stream(); + } - arguments[i] = - Arguments.of( - Named.of( - "arg1: " + payload[0].toHexString() + ", arg2: " + payload[1].toHexString(), - payload)); + @Override + protected Stream provideNonRandomArguments() { + List arguments = new ArrayList<>(); + for (OpCode opCode : getModuleTracer().supportedOpCodes()) { + for (int k = 1; k <= 4; k++) { + for (int i = 1; i <= 4; i++) { + arguments.add(Arguments.of(opCode, UInt256.valueOf(i), UInt256.valueOf(k))); + } + } } - - return Stream.of(arguments); + return arguments.stream(); } - public static Stream provideMulOperators() { - return Stream.of( - Arguments.of(Named.of("MUL", (int) OpCode.MUL.value)), - Arguments.of(Named.of("EXP", (int) OpCode.EXP.value))); + @Override + protected ModuleTracer getModuleTracer() { + return new MulTracer(); } } From 3a8c14c906aa1b0d21f8ee13e9cc35a748183ba6 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Thu, 20 Apr 2023 12:09:57 +1000 Subject: [PATCH 22/31] tests passing and corset validator passing Signed-off-by: Sally MacFarlane --- .../zktracer/module/alu/mul/MulData.java | 5 +++ .../zktracer/module/alu/mul/MulTracer.java | 16 +++---- .../module/alu/mul/MulTracerTest.java | 35 +++++++++++----- .../zktracer/module/alu/mul/MulUtilsTest.java | 42 +++++++++++++++++++ 4 files changed, 80 insertions(+), 18 deletions(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index 62d23eaf6d..2eee5ae2bb 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -69,6 +69,11 @@ public MulData(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { this.tinyBase = isTiny(arg1BigInt); this.tinyExponent = isTiny(arg2BigInt); + // initialize bits + for (int i = 0; i < bits.length; i++) { + bits[i] = false; + } + final Regime regime = getRegime(); switch (regime) { case TRIVIAL_MUL -> {} diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java index 7e1a051852..be5f5afdba 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java @@ -132,10 +132,10 @@ private void trace(final MulTrace.Trace.Builder builder, final MulData data, fin .appendByteC1(UnsignedByte.of(data.cBytes.get(1, i))) .appendByteC0(UnsignedByte.of(data.cBytes.get(0, i))); builder - .appendAccB3(data.cBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) - .appendAccB2(data.cBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) - .appendAccB1(data.cBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) - .appendAccB0(data.cBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); + .appendAccC3(data.cBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) + .appendAccC2(data.cBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) + .appendAccC1(data.cBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) + .appendAccC0(data.cBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); builder .appendByteH3(UnsignedByte.of(data.hBytes.get(3, i))) @@ -143,10 +143,10 @@ private void trace(final MulTrace.Trace.Builder builder, final MulData data, fin .appendByteH1(UnsignedByte.of(data.hBytes.get(1, i))) .appendByteH0(UnsignedByte.of(data.hBytes.get(0, i))); builder - .appendAccB3(data.hBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) - .appendAccB2(data.hBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) - .appendAccB1(data.hBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) - .appendAccB0(data.hBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); + .appendAccH3(data.hBytes.getRange(3, 0, i + 1).toUnsignedBigInteger()) + .appendAccH2(data.hBytes.getRange(2, 0, i + 1).toUnsignedBigInteger()) + .appendAccH1(data.hBytes.getRange(1, 0, i + 1).toUnsignedBigInteger()) + .appendAccH0(data.hBytes.getRange(0, 0, i + 1).toUnsignedBigInteger()); builder .appendExponentBit(data.exponentBit()) .appendExponentBitAcc(data.expAcc.toUnsignedBigInteger()) diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java index 5f0ebf7fee..5fbc1dde61 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java @@ -44,8 +44,8 @@ void aluMulTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { } @ParameterizedTest() - @MethodSource("provideSimpleAluMulArguments") - void simpleTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + @MethodSource("singleTinyExponentiation") + void testSingleTinyExponentiation(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { runTest(opCode, arg1, arg2); } @@ -67,12 +67,18 @@ void randomNonTinyArgsTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { runTest(opCode, arg1, arg2); } - public Stream provideSimpleAluMulArguments() { + @ParameterizedTest() + @MethodSource("multiplyByZero") + void zerosArgsTest(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { + runTest(opCode, arg1, arg2); + } + + public Stream singleTinyExponentiation() { List arguments = new ArrayList<>(); - Bytes32 bytes1 = Bytes32.rightPad(Bytes.fromHexString("0x80")); - Bytes32 bytes2 = Bytes32.leftPad(Bytes.fromHexString("0x01")); - arguments.add(Arguments.of(getRandomSupportedOpcode(), bytes1, bytes2)); + Bytes32 bytes1 = Bytes32.leftPad(Bytes.fromHexString("0x13")); + Bytes32 bytes2 = Bytes32.leftPad(Bytes.fromHexString("0x02")); + arguments.add(Arguments.of(OpCode.EXP, bytes1, bytes2)); return arguments.stream(); } @@ -122,11 +128,9 @@ public Stream provideRandomNonTinyArguments() { public Stream provideTinyArguments() { List arguments = new ArrayList<>(); - for (int i = 0; i < 4; i++) { arguments.add(getRandomAluMulInstruction(i, i + 1)); } - return arguments.stream(); } @@ -134,8 +138,8 @@ public Stream provideTinyArguments() { protected Stream provideNonRandomArguments() { List arguments = new ArrayList<>(); for (OpCode opCode : getModuleTracer().supportedOpCodes()) { - for (int k = 1; k <= 4; k++) { - for (int i = 1; i <= 4; i++) { + for (int k = 0; k <= 3; k++) { + for (int i = 0; i <= 3; i++) { arguments.add(Arguments.of(opCode, UInt256.valueOf(i), UInt256.valueOf(k))); } } @@ -143,6 +147,17 @@ protected Stream provideNonRandomArguments() { return arguments.stream(); } + protected Stream multiplyByZero() { + List arguments = new ArrayList<>(); + for (int i = 0; i < 2; i++) { + Bytes32 bytes1 = Bytes32.ZERO; + Bytes32 bytes2 = UInt256.valueOf(i); + arguments.add(Arguments.of(OpCode.MUL, bytes1, bytes2)); + arguments.add(Arguments.of(OpCode.MUL, bytes2, bytes1)); + } + return arguments.stream(); + } + @Override protected ModuleTracer getModuleTracer() { return new MulTracer(); diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java index 379b51afdf..bd8caac74c 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java @@ -4,12 +4,19 @@ import java.math.BigInteger; +import net.consensys.linea.zktracer.OpCode; +import net.consensys.linea.zktracer.bytes.Bytes16; +import net.consensys.linea.zktracer.bytes.UnsignedByte; +import net.consensys.linea.zktracer.module.Util; +import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; public class MulUtilsTest { @Test public void isTiny() { + // tiny means zero or one assertThat(MulData.isTiny(BigInteger.ZERO)).isTrue(); assertThat(MulData.isTiny(BigInteger.ONE)).isTrue(); assertThat(MulData.isTiny(BigInteger.TWO)).isFalse(); @@ -23,4 +30,39 @@ public void twoAdicity() { // assertThat(MulData.twoAdicity(UInt256.MAX_VALUE)).isEqualTo(0); // assertThat(MulData.twoAdicity(UInt256.valueOf(1))).isEqualTo(0); } + + @Test + public void multiplyByZero() { + Bytes32 arg1 = Bytes32.random(); + OpCode mul = OpCode.MUL; + MulData oxo = new MulData(mul, arg1, Bytes32.ZERO); + Assertions.assertThat(oxo.arg2Hi.isZero()).isTrue(); + Assertions.assertThat(oxo.arg2Lo).isEqualTo(Bytes16.ZERO); + Assertions.assertThat(oxo.arg2Hi).isEqualTo(Bytes16.ZERO); + assertThat(oxo.opCode).isEqualTo(mul); + assertThat(oxo.tinyExponent).isTrue(); + assertThat(oxo.isOneLineInstruction()).isTrue(); + assertThat(oxo.bits[0]).isFalse(); + } + + @Test + public void zeroExp() { + Bytes32 arg1 = Bytes32.random(); + OpCode mul = OpCode.EXP; + MulData oxo = new MulData(mul, arg1, Bytes32.ZERO); + Assertions.assertThat(oxo.arg2Hi.isZero()).isTrue(); + Assertions.assertThat(oxo.arg2Lo).isEqualTo(Bytes16.ZERO); + Assertions.assertThat(oxo.arg2Hi).isEqualTo(Bytes16.ZERO); + assertThat(oxo.opCode).isEqualTo(mul); + assertThat(oxo.tinyExponent).isTrue(); + assertThat(oxo.isOneLineInstruction()).isTrue(); + assertThat(oxo.bits[0]).isFalse(); + } + + @Test + public void testByteBits() { + Boolean[] booleans = Util.byteBits(UnsignedByte.of(0)); + assertThat(booleans.length).isEqualTo(8); + assertThat(booleans[0]).isNotNull(); + } } From 84db2a8e6b9187f3e8027891a902d2db6e919f44 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Thu, 20 Apr 2023 12:37:17 +1000 Subject: [PATCH 23/31] initialize bits[0,1] Signed-off-by: Sally MacFarlane --- .../net/consensys/linea/zktracer/module/alu/mul/MulData.java | 2 ++ .../consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index 2eee5ae2bb..76e05f7aa0 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -356,6 +356,8 @@ public void setHsAndBits(UInt256 a, UInt256 b) { int mu = getOverflow(sum, 3, "mu OOB"); + bits[0] = false; + bits[1] = false; bits[2] = getBit(alpha, 0); bits[3] = getBit(beta, 0); bits[4] = getBit(beta, 1); diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java index bd8caac74c..644ef214cc 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java @@ -60,7 +60,7 @@ public void zeroExp() { } @Test - public void testByteBits() { + public void testByteBits_ofZero() { Boolean[] booleans = Util.byteBits(UnsignedByte.of(0)); assertThat(booleans.length).isEqualTo(8); assertThat(booleans[0]).isNotNull(); From 246860bc06070dbd46b65ea861b442b496004202 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Thu, 20 Apr 2023 14:13:09 +1000 Subject: [PATCH 24/31] add final zero to the zero Signed-off-by: Sally MacFarlane --- .../zktracer/module/alu/mul/MulTracer.java | 35 ++++++++++--------- .../module/alu/mul/MulTracerTest.java | 2 +- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java index be5f5afdba..4ac38dc899 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTracer.java @@ -46,40 +46,43 @@ public Object trace(MessageFrame frame) { final OpCode opCode = OpCode.of(frame.getCurrentOperation().getOpcode()); - final MulData data = new MulData(opCode, arg1, arg2); + // argument order is reversed ?? + final MulData data = new MulData(opCode, arg2, arg1); final MulTrace.Trace.Builder builder = MulTrace.Trace.Builder.newInstance(); - final int maxCt = data.maxCt(); - - stamp++; switch (data.getRegime()) { case EXPONENT_ZERO_RESULT -> { - for (int ct = 0; ct < maxCt; ct++) { - trace(builder, data, ct); - } - return builder.build(); + trace(builder, data); } case EXPONENT_NON_ZERO_RESULT -> { if (data.carryOn()) { data.update(); - for (int ct = 0; ct < maxCt; ct++) { - trace(builder, data, ct); - } + trace(builder, data); } - return builder.build(); } case TRIVIAL_MUL, NON_TRIVIAL_MUL -> { data.setHsAndBits(UInt256.fromBytes(arg1), UInt256.fromBytes(arg2)); - for (int ct = 0; ct < maxCt; ct++) { - trace(builder, data, ct); - } - return builder.build(); + trace(builder, data); } default -> throw new RuntimeException("regime not supported"); } + // TODO captureBlockEnd should be called from elsewhere - not within messageFrame + // captureBlockEnd(); + MulData finalZeroToTheZero = new MulData(OpCode.EXP, Bytes32.ZERO, Bytes32.ZERO); + trace(builder, finalZeroToTheZero); + + return builder.build(); + } + + private void trace(final MulTrace.Trace.Builder builder, final MulData data) { + stamp++; + + for (int ct = 0; ct < data.maxCt(); ct++) { + trace(builder, data, ct); + } } private void trace(final MulTrace.Trace.Builder builder, final MulData data, final int i) { diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java index 5fbc1dde61..4b08a0d6bf 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulTracerTest.java @@ -114,7 +114,7 @@ public Stream provideSpecificNonTinyArguments() { Bytes32.fromHexString("0x8a48aa20e200ce3fee16b5dcdec5c4faff613bc914d47cd6ca69553f8eb2b377"); payload[1] = Bytes32.fromHexString("0x59b635fec894caa3ed6817b1e67b3cbaeb8757fd6c7b03119b795303b7cd72c1"); - return Stream.of(Arguments.of(getRandomSupportedOpcode(), payload[0], payload[1])); + return Stream.of(Arguments.of(OpCode.MUL, payload[0], payload[1])); } public Stream provideRandomNonTinyArguments() { From 31e90d29f9411e7836a548aa788e0a93d0123d92 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Mon, 24 Apr 2023 10:01:36 +1000 Subject: [PATCH 25/31] OLI Signed-off-by: Sally MacFarlane --- .../net/consensys/linea/zktracer/module/alu/mul/MulTrace.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java index c13fa5a91e..18b7cb1cc7 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulTrace.java @@ -124,7 +124,7 @@ public record Trace( @JsonProperty("EXPONENT_BIT_SOURCE") List EXPONENT_BIT_SOURCE, @JsonProperty("INST") List INST, @JsonProperty("MUL_STAMP") List MUL_STAMP, - @JsonProperty("ONE_LINE_INSTRUCTION") List ONE_LINE_INSTRUCTION, + @JsonProperty("OLI") List ONE_LINE_INSTRUCTION, @JsonProperty("RESULT_VANISHES") List RESULT_VANISHES, @JsonProperty("RES_HI") List RES_HI, @JsonProperty("RES_LO") List RES_LO, From 14c2c7ee75892708950e753051e51bf16e9b0aa6 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Mon, 24 Apr 2023 10:09:44 +1000 Subject: [PATCH 26/31] fixed shiftLeft related bug Signed-off-by: Sally MacFarlane --- .../consensys/linea/zktracer/module/Util.java | 26 ++- .../zktracer/module/alu/mul/MulData.java | 45 ++--- .../zktracer/module/alu/mul/MulUtilsTest.java | 159 ++++++++++++++++++ 3 files changed, 200 insertions(+), 30 deletions(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/Util.java b/src/main/java/net/consensys/linea/zktracer/module/Util.java index 45ad63477c..3f24aff786 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/Util.java +++ b/src/main/java/net/consensys/linea/zktracer/module/Util.java @@ -35,20 +35,28 @@ public static Boolean[] byteBits(final UnsignedByte b) { return bits; } - public static int getOverflow(final UInt256 arg, final int maxVal, final String err) { - UInt256 shiftRight = arg.shiftRight(128); - if (shiftRight.compareTo(UInt64.MAX_VALUE.toBytes()) > 0) { - throw new RuntimeException("getOverflow expects a small high part"); + // in Go implementation this method modifies the arg param + // however (at least in MUL module) the modified value is never used + // so have not gone to any effort to recreate that behavior in Java implementation + public static UInt64 getOverflow(final UInt256 arg, final UInt64 maxVal, final String err) { + UInt256 shifted = arg.shiftRight(128); + if (shifted.compareTo(UInt64.MAX_VALUE.toBytes()) > 0) { + // in Go this is panic() but caught by the calling func + // throw new RuntimeException("getOverflow expects a small high part"); + return UInt64.ZERO; } - int overflow = shiftRight.intValue(); - if (overflow > maxVal) { - throw new RuntimeException(err); + + UInt64 overflow = UInt64.fromBytes(shifted.trimLeadingZeros()); + if (overflow.compareTo(maxVal) > 0) { + // in Go this is panic() but caught by the calling func + // throw new RuntimeException(err + " overflow=" + overflow); + return UInt64.ZERO; } return overflow; } // GetBit returns true iff the k'th bit of x is 1 - public static boolean getBit(int x, int k) { - return (x >> k) % 2 == 1; + public static boolean getBit(UInt64 x, int k) { + return (x.shiftRight(k)).mod(2).equals(UInt64.ONE); } } diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index 76e05f7aa0..acbf9c6fc4 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -16,8 +16,8 @@ import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; +import org.apache.tuweni.units.bigints.UInt64; -@SuppressWarnings("UnusedVariable") public class MulData { private static final int MMEDIUM = 8; final OpCode opCode; @@ -110,7 +110,7 @@ private void setArraysForZeroResultCase() { bytes = aBytes.get(1); } else { for (int i = 0; i < 8; i++) { - cBytes.setBytes(0, ones); + cBytes.setChunk(0, ones); } bytes = aBytes.get(0); } @@ -164,7 +164,7 @@ private void setArraysForZeroResultCase() { } final BaseTheta thing = BaseTheta.fromBytes32(UInt256.valueOf(target)); - hBytes.setBytes(1, thing.get(0)); + hBytes.setChunk(1, thing.get(0)); } return; @@ -285,7 +285,6 @@ private int bitNum(int i, int length) { public void update() { final BigInteger arg1BigInt = UInt256.fromBytes(arg1).toUnsignedBigInteger(); - final BigInteger arg2BigInt = UInt256.fromBytes(arg2).toUnsignedBigInteger(); if (!snm) { // squaring setHsAndBits(resAcc, resAcc); @@ -302,10 +301,11 @@ public void update() { public void setHsAndBits(UInt256 a, UInt256 b) { - BaseTheta aBaseTheta = BaseTheta.fromBytes32(a); - BaseTheta bBaseTheta = BaseTheta.fromBytes32(b); - BaseTheta sumBaseTheta; + setHsAndBitsFromBaseThetas(BaseTheta.fromBytes32(a), BaseTheta.fromBytes32(b)); + } + public void setHsAndBitsFromBaseThetas(BaseTheta aBaseTheta, BaseTheta bBaseTheta) { + BaseTheta sumBaseTheta; UInt256[] aBaseThetaInts = (UInt256[]) Array.newInstance(UInt256.class, 4); UInt256[] bBaseThetaInts = (UInt256[]) Array.newInstance(UInt256.class, 4); @@ -321,9 +321,9 @@ public void setHsAndBits(UInt256 a, UInt256 b) { sum = sum.add(prod); // sum += a0 * b1 sumBaseTheta = BaseTheta.fromBytes32(sum); - hBytes.setBytes(0, sumBaseTheta.get(0)); - hBytes.setBytes(1, sumBaseTheta.get(1)); - int alpha = getOverflow(sum, 1, "alpha OOB"); + hBytes.setChunk(0, sumBaseTheta.get(0)); + hBytes.setChunk(1, sumBaseTheta.get(1)); + UInt64 alpha = getOverflow(sum, UInt64.ONE, "alpha OOB"); sum = aBaseThetaInts[3].multiply(bBaseThetaInts[0]); // sum := a3 * b0 prod = aBaseThetaInts[2].multiply(bBaseThetaInts[1]); @@ -334,27 +334,31 @@ public void setHsAndBits(UInt256 a, UInt256 b) { sum = sum.add(prod); // sum += a0 * b3 sumBaseTheta = BaseTheta.fromBytes32(sum); - hBytes.setBytes(2, sumBaseTheta.get(0)); - hBytes.setBytes(3, sumBaseTheta.get(1)); - int beta = getOverflow(sum, 3, "beta OOB"); + hBytes.setChunk(2, sumBaseTheta.get(0)); + hBytes.setChunk(3, sumBaseTheta.get(1)); + UInt64 beta = getOverflow(sum, UInt64.valueOf(3), "beta OOB"); + + prod = aBaseThetaInts[0].multiply(bBaseThetaInts[0]); + sum = sum.add(prod); // sum := a0 * b0 - sum = aBaseThetaInts[0].multiply(bBaseThetaInts[0]); // sum := a0 * b0 - sum = sum.add(UInt256.fromBytes(hBytes.get(0).shiftLeft(64))); // sum += (h0 << 64) + prod = UInt256.fromBytes(hBytes.get(0)).shiftLeft(64); + sum = sum.add(prod); // sum += (h0 << 64) - int eta = getOverflow(sum, 1, "eta OOB"); + UInt64 eta = getOverflow(sum, UInt64.ONE, "eta OOB"); - sum = UInt256.valueOf(eta); // sum := eta + sum = UInt256.fromBytes(eta.toBytes()); // sum := eta sum = sum.add(UInt256.fromBytes(hBytes.get(1))); // sum += h1 - sum = sum.add(UInt256.valueOf(alpha).shiftLeft(64)); // sum += (alpha << 64) + prod = UInt256.fromBytes(alpha.toBytes()).shiftLeft(64); + sum = sum.add(prod); // sum += (alpha << 64) prod = aBaseThetaInts[2].multiply(bBaseThetaInts[0]); sum = sum.add(prod); // sum += a2 * b0 prod = aBaseThetaInts[1].multiply(bBaseThetaInts[1]); sum = sum.add(prod); // sum += a1 * b1 prod = aBaseThetaInts[0].multiply(bBaseThetaInts[2]); sum = sum.add(prod); // sum += a0 * b2 - sum = sum.add(UInt256.fromBytes(hBytes.get(2).shiftLeft(64))); // sum += (h2 << 64) + sum = sum.add(UInt256.fromBytes(hBytes.get(2)).shiftLeft(64)); // sum += (h2 << 64) - int mu = getOverflow(sum, 3, "mu OOB"); + UInt64 mu = getOverflow(sum, UInt64.valueOf(3), "mu OOB"); bits[0] = false; bits[1] = false; @@ -364,7 +368,6 @@ public void setHsAndBits(UInt256 a, UInt256 b) { bits[5] = getBit(eta, 0); bits[6] = getBit(mu, 0); bits[7] = getBit(mu, 1); - return; } diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java index 644ef214cc..76686d0c41 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java @@ -7,9 +7,12 @@ import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.bytes.Bytes16; import net.consensys.linea.zktracer.bytes.UnsignedByte; +import net.consensys.linea.zktracer.bytestheta.BaseTheta; import net.consensys.linea.zktracer.module.Util; +import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; +import org.apache.tuweni.units.bigints.UInt64; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; @@ -65,4 +68,160 @@ public void testByteBits_ofZero() { assertThat(booleans.length).isEqualTo(8); assertThat(booleans[0]).isNotNull(); } + + @Test + public void hBytesAllZeros() { + Bytes32 arg1 = Bytes32.ZERO; + Bytes32 arg2 = Bytes32.ZERO; + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas( + BaseTheta.fromBytes32(UInt256.ZERO), BaseTheta.fromBytes32(UInt256.ZERO)); + assertThat(mulData.hBytes.get(0).isZero()).isTrue(); + assertThat(mulData.hBytes.get(1).isZero()).isTrue(); + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + + assertThat(mulData.hBytes.get(0).shiftLeft(64)).isEqualTo(mulData.hBytes.get(1)); // ZERO + } + + @Test + public void hBytesWhereOneArgIsZero() { + Bytes32 arg1 = Bytes32.ZERO; + Bytes32 arg2 = Bytes32.ZERO; + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas( + BaseTheta.fromBytes32(UInt256.ZERO), BaseTheta.fromBytes32(UInt256.valueOf(1))); + assertThat(mulData.hBytes.get(0).isZero()).isTrue(); + assertThat(mulData.hBytes.get(1).isZero()).isTrue(); + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + + assertThat(mulData.hBytes.get(0).shiftLeft(64)).isEqualTo(mulData.hBytes.get(1)); // ZERO + } + + @Test + public void hBytes_5_5() { + Bytes32 arg1 = Bytes32.fromHexString("0x05"); + Bytes32 arg2 = Bytes32.fromHexString("0x05"); + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas(BaseTheta.fromBytes32(arg1), BaseTheta.fromBytes32(arg2)); + assertThat(mulData.hBytes.get(0).isZero()).isTrue(); + assertThat(mulData.hBytes.get(1).isZero()).isTrue(); + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + } + + @Test + public void hBytes_largeEnoughArgsToGetNonZeros() { + + // these args aren't used directly in the calculation of hs and bits + Bytes32 arg1 = Bytes32.fromHexString("0x05"); + Bytes32 arg2 = Bytes32.fromHexString("0x05"); + + BaseTheta aBaseTheta = BaseTheta.fromBytes32(UInt256.valueOf(43532)); // aa 0c // 170 12 + + // squaring is one way to get a sufficiently big number + BigInteger b = BigInteger.valueOf(82494664664768L); + BigInteger b2 = b.multiply(b); // 6805369698152522037820493824 + UInt256 b2uint = UInt256.valueOf(b2); + + BaseTheta bBaseTheta = BaseTheta.fromBytes32(b2uint); + + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas(aBaseTheta, bBaseTheta); + + Bytes h0 = Bytes.fromHexString("0x00000e9b37bfd908"); + + assertThat(mulData.hBytes.get(0)).isEqualTo(h0); + assertThat(mulData.hBytes.get(1).isZero()).isTrue(); + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + } + + @Test + public void hBytes_aReallyBigNumber() { + + // these args aren't used directly in the calculation of hs and bits + Bytes32 arg1 = Bytes32.fromHexString("0x05"); + Bytes32 arg2 = Bytes32.fromHexString("0x05"); + + BigInteger b = new BigInteger("296251353699975589350401737146368"); + UInt256 b4uint = UInt256.valueOf(b); + + BaseTheta aBaseTheta = BaseTheta.fromBytes32(b4uint); + BaseTheta bBaseTheta = BaseTheta.fromBytes32(b4uint); + + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas(aBaseTheta, bBaseTheta); + + Bytes h0 = Bytes.fromHexString("0xe9e4064d86460000"); + Bytes h1 = Bytes.fromHexString("0x00000779079ae9e3"); + + assertThat(mulData.hBytes.get(0)).isEqualTo(h0); + assertThat(mulData.hBytes.get(1)).isEqualTo(h1); + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + } + + @Test + public void hBytesValue_and_generatesATrueBit() { + + // these args aren't used directly in the calculation of hs and bits + Bytes32 arg1 = Bytes32.fromHexString("0x05"); + Bytes32 arg2 = Bytes32.fromHexString("0x05"); + + BigInteger b = new BigInteger("3672491014949214151879080813"); + UInt256 b4uint = UInt256.valueOf(b); + + BaseTheta aBaseTheta = BaseTheta.fromBytes32(b4uint); + BaseTheta bBaseTheta = BaseTheta.fromBytes32(b4uint); + + MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas(aBaseTheta, bBaseTheta); + + Bytes h0 = Bytes.fromHexString("0x8d222c056a351fb0"); + Bytes h1 = Bytes.fromHexString("0x0000000013fc1cf0"); + + assertThat(mulData.hBytes.get(2).isZero()).isTrue(); + assertThat(mulData.hBytes.get(3).isZero()).isTrue(); + assertThat(mulData.hBytes.get(0)).isEqualTo(h0); + assertThat(mulData.hBytes.get(1)).isEqualTo(h1); + + // bits + // expected value obtained from go implementation debug output + Boolean[] expectedBools = {false, false, false, false, false, true, false, false}; + assertThat(mulData.bits).isEqualTo(expectedBools); + } + + @Test + public void hBytes_twoReallyBigNumbers_generatesADifferentTrueBit() { + + // these args aren't used directly in the calculation of hs and bits + Bytes32 arg1 = Bytes32.fromHexString("0x05"); + Bytes32 arg2 = Bytes32.fromHexString("0x05"); + + BigInteger b1 = + new BigInteger( + "22469423347992668196557015132986860313508181747976369840918221307635594854917"); + UInt256 b1uint = UInt256.valueOf(b1); + BigInteger b2 = + new BigInteger( + "24978870742348927442211038043709369780953629326985996012322553405261539030097"); + UInt256 b2uint = UInt256.valueOf(b2); + + BaseTheta aBaseTheta = BaseTheta.fromBytes32(b1uint); + BaseTheta bBaseTheta = BaseTheta.fromBytes32(b2uint); + + final MulData mulData = new MulData(OpCode.EXP, arg1, arg2); + mulData.setHsAndBitsFromBaseThetas(aBaseTheta, bBaseTheta); + + BigInteger sum_010 = new BigInteger("375860551383434850958895718584879559103"); + assertThat(Util.getOverflow(UInt256.valueOf(sum_010), UInt64.valueOf(3), "mu OOB")) + .isEqualTo(UInt64.ONE); + + // bits + // expected value obtained from go implementation debug output + Boolean[] expectedBools = {false, false, false, false, false, false, true, false}; + assertThat(mulData.bits).isEqualTo(expectedBools); + } } From 06179c9c1e9b6de78a7f4f91f984e92aea9aec03 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 26 Apr 2023 13:00:32 +1000 Subject: [PATCH 27/31] throw exception if wrong opcode Signed-off-by: Sally MacFarlane --- .../net/consensys/linea/zktracer/module/alu/mul/MulData.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index 0bc957ce81..af89313be6 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -91,7 +91,7 @@ private static BaseBytes getRes(OpCode opCode, Bytes32 arg1, Bytes32 arg2) { return switch (opCode) { case MUL -> BaseBytes.fromBytes32(UInt256.fromBytes(arg1).multiply(UInt256.fromBytes(arg2))); case EXP -> BaseBytes.fromBytes32(UInt256.fromBytes(arg1).pow(UInt256.fromBytes(arg2))); - default -> BaseBytes.fromBytes32(UInt256.ZERO); + default -> throw new RuntimeException("MUL module was given wrong opcode"); }; } From ccbdd7b025296c824d4ac78b824a88fe30948300 Mon Sep 17 00:00:00 2001 From: Sally MacFarlane Date: Wed, 26 Apr 2023 13:04:15 +1000 Subject: [PATCH 28/31] fixed assignment of sum Signed-off-by: Sally MacFarlane --- .../net/consensys/linea/zktracer/module/alu/mul/MulData.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index af89313be6..847c6fb215 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -338,7 +338,7 @@ public void setHsAndBitsFromBaseThetas(BaseTheta aBaseTheta, BaseTheta bBaseThet long beta = getOverflow(sum, 3, "beta OOB"); prod = aBaseThetaInts[0].multiply(bBaseThetaInts[0]); - sum = sum.add(prod); // sum := a0 * b0 + sum = prod; // sum := a0 * b0 prod = UInt256.fromBytes(hBytes.get(0)).shiftLeft(64); sum = sum.add(prod); // sum += (h0 << 64) From 1bfd5de877e9ecc458e75f48f4528273467ac183 Mon Sep 17 00:00:00 2001 From: Gabriel-Trintinalia Date: Wed, 26 Apr 2023 14:26:58 +1000 Subject: [PATCH 29/31] Move ext test to the right package --- .../zktracer/{ => corset}/module/alu/ext/ExtTracerTest.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) rename src/test/java/net/consensys/linea/zktracer/{ => corset}/module/alu/ext/ExtTracerTest.java (97%) diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java b/src/test/java/net/consensys/linea/zktracer/corset/module/alu/ext/ExtTracerTest.java similarity index 97% rename from src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java rename to src/test/java/net/consensys/linea/zktracer/corset/module/alu/ext/ExtTracerTest.java index 8170e01673..b5009fd0e4 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/corset/module/alu/ext/ExtTracerTest.java @@ -12,7 +12,7 @@ * * SPDX-License-Identifier: Apache-2.0 */ -package net.consensys.linea.zktracer.module.alu.ext; +package net.consensys.linea.zktracer.corset.module.alu.ext; import java.util.ArrayList; import java.util.List; @@ -22,6 +22,7 @@ import net.consensys.linea.zktracer.AbstractModuleTracerCorsetTest; import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.module.ModuleTracer; +import net.consensys.linea.zktracer.module.alu.ext.ExtTracer; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; import org.junit.jupiter.api.Assertions; From 36212fcb85b6361b703f35ce1820cb96f2762b2a Mon Sep 17 00:00:00 2001 From: Gabriel-Trintinalia Date: Wed, 26 Apr 2023 14:37:35 +1000 Subject: [PATCH 30/31] Return modules to previous package --- .../zktracer/{corset => }/module/alu/add/AddTracerTest.java | 2 +- .../linea/zktracer/{corset => }/module/alu/add/AdderTest.java | 2 +- .../zktracer/{corset => }/module/alu/ext/ExtTracerTest.java | 2 +- .../zktracer/{corset => }/module/alu/mod/ModTracerTest.java | 2 +- .../linea/zktracer/{corset => }/module/shf/ShfTracerTest.java | 3 +-- .../linea/zktracer/{corset => }/module/wcp/WcpTracerTest.java | 2 +- 6 files changed, 6 insertions(+), 7 deletions(-) rename src/test/java/net/consensys/linea/zktracer/{corset => }/module/alu/add/AddTracerTest.java (98%) rename src/test/java/net/consensys/linea/zktracer/{corset => }/module/alu/add/AdderTest.java (97%) rename src/test/java/net/consensys/linea/zktracer/{corset => }/module/alu/ext/ExtTracerTest.java (98%) rename src/test/java/net/consensys/linea/zktracer/{corset => }/module/alu/mod/ModTracerTest.java (98%) rename src/test/java/net/consensys/linea/zktracer/{corset => }/module/shf/ShfTracerTest.java (97%) rename src/test/java/net/consensys/linea/zktracer/{corset => }/module/wcp/WcpTracerTest.java (98%) diff --git a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AddTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AddTracerTest.java similarity index 98% rename from src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AddTracerTest.java rename to src/test/java/net/consensys/linea/zktracer/module/alu/add/AddTracerTest.java index dd4a384eb3..3744d0623f 100644 --- a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AddTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AddTracerTest.java @@ -12,7 +12,7 @@ * * SPDX-License-Identifier: Apache-2.0 */ -package net.consensys.linea.zktracer.corset.module.alu.add; +package net.consensys.linea.zktracer.module.alu.add; import java.util.ArrayList; import java.util.List; diff --git a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AdderTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AdderTest.java similarity index 97% rename from src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AdderTest.java rename to src/test/java/net/consensys/linea/zktracer/module/alu/add/AdderTest.java index 3b28987e43..b0cae0758a 100644 --- a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/add/AdderTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AdderTest.java @@ -12,7 +12,7 @@ * * SPDX-License-Identifier: Apache-2.0 */ -package net.consensys.linea.zktracer.corset.module.alu.add; +package net.consensys.linea.zktracer.module.alu.add; import static org.assertj.core.api.Assertions.assertThat; diff --git a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/ext/ExtTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java similarity index 98% rename from src/test/java/net/consensys/linea/zktracer/corset/module/alu/ext/ExtTracerTest.java rename to src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java index b5009fd0e4..0f174d6dba 100644 --- a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/ext/ExtTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java @@ -12,7 +12,7 @@ * * SPDX-License-Identifier: Apache-2.0 */ -package net.consensys.linea.zktracer.corset.module.alu.ext; +package net.consensys.linea.zktracer.module.alu.ext; import java.util.ArrayList; import java.util.List; diff --git a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/mod/ModTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mod/ModTracerTest.java similarity index 98% rename from src/test/java/net/consensys/linea/zktracer/corset/module/alu/mod/ModTracerTest.java rename to src/test/java/net/consensys/linea/zktracer/module/alu/mod/ModTracerTest.java index 4c1a84a0f7..5da8d66b99 100644 --- a/src/test/java/net/consensys/linea/zktracer/corset/module/alu/mod/ModTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mod/ModTracerTest.java @@ -12,7 +12,7 @@ * * SPDX-License-Identifier: Apache-2.0 */ -package net.consensys.linea.zktracer.corset.module.alu.mod; +package net.consensys.linea.zktracer.module.alu.mod; import java.math.BigInteger; import java.util.ArrayList; diff --git a/src/test/java/net/consensys/linea/zktracer/corset/module/shf/ShfTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/shf/ShfTracerTest.java similarity index 97% rename from src/test/java/net/consensys/linea/zktracer/corset/module/shf/ShfTracerTest.java rename to src/test/java/net/consensys/linea/zktracer/module/shf/ShfTracerTest.java index ac0bcd7bc5..1d3609fa83 100644 --- a/src/test/java/net/consensys/linea/zktracer/corset/module/shf/ShfTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/shf/ShfTracerTest.java @@ -12,7 +12,7 @@ * * SPDX-License-Identifier: Apache-2.0 */ -package net.consensys.linea.zktracer.corset.module.shf; +package net.consensys.linea.zktracer.module.shf; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.mockito.Mockito.when; @@ -28,7 +28,6 @@ import net.consensys.linea.zktracer.ZkTraceBuilder; import net.consensys.linea.zktracer.ZkTracer; import net.consensys.linea.zktracer.corset.CorsetValidator; -import net.consensys.linea.zktracer.module.shf.ShfTracer; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.junit.jupiter.api.BeforeEach; diff --git a/src/test/java/net/consensys/linea/zktracer/corset/module/wcp/WcpTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/wcp/WcpTracerTest.java similarity index 98% rename from src/test/java/net/consensys/linea/zktracer/corset/module/wcp/WcpTracerTest.java rename to src/test/java/net/consensys/linea/zktracer/module/wcp/WcpTracerTest.java index 038bde857e..337c95daae 100644 --- a/src/test/java/net/consensys/linea/zktracer/corset/module/wcp/WcpTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/wcp/WcpTracerTest.java @@ -12,7 +12,7 @@ * * SPDX-License-Identifier: Apache-2.0 */ -package net.consensys.linea.zktracer.corset.module.wcp; +package net.consensys.linea.zktracer.module.wcp; import static net.consensys.linea.zktracer.OpCode.SGT; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; From 76b2a3954445fed42b124eee4cb19331cf1cca70 Mon Sep 17 00:00:00 2001 From: Gabriel-Trintinalia Date: Wed, 26 Apr 2023 16:43:27 +1000 Subject: [PATCH 31/31] Make ext and mod test classes extend abstract class --- build.gradle | 12 +++ .../consensys/linea/zktracer/module/Util.java | 2 +- .../zktracer/module/alu/mul/MulData.java | 14 ++++ .../AbstractModuleTracerCorsetTest.java | 2 + .../module/alu/add/AddTracerTest.java | 1 - .../zktracer/module/alu/add/AdderTest.java | 1 - .../module/alu/ext/ExtTracerTest.java | 1 - .../module/alu/mod/ModTracerTest.java | 1 - .../zktracer/module/alu/mul/MulUtilsTest.java | 14 ++++ .../zktracer/module/shf/ShfTracerTest.java | 2 + .../zktracer/module/wcp/WcpTracerTest.java | 75 +++---------------- 11 files changed, 55 insertions(+), 70 deletions(-) diff --git a/build.gradle b/build.gradle index 09dd4c0f35..8a70608865 100644 --- a/build.gradle +++ b/build.gradle @@ -215,6 +215,18 @@ allprojects { useJUnitPlatform() } + task unitTests(type: Test) { + useJUnitPlatform { + excludeTags("CorsetTest") + } + } + + task corsetTests(type: Test) { + useJUnitPlatform { + includeTags("CorsetTest") + } + } + javadoc { options.addStringOption('Xdoclint:all', '-quiet') options.addStringOption('Xwerror', '-html5') diff --git a/src/main/java/net/consensys/linea/zktracer/module/Util.java b/src/main/java/net/consensys/linea/zktracer/module/Util.java index 37d41173ca..f97153ff01 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/Util.java +++ b/src/main/java/net/consensys/linea/zktracer/module/Util.java @@ -119,7 +119,7 @@ public static UInt256 multiplyRange(Bytes[] range1, Bytes[] range2) { UInt256 sum = UInt256.ZERO; for (int i = 0; i < range1.length; i++) { UInt256 prod = - UInt256.fromBytes(range1[i]).multiply(UInt256.fromBytes(range2[range2.length - i - 1])); + UInt256.fromBytes(range1[i]).multiply(UInt256.fromBytes(range2[range2.length - i - 1])); sum = sum.add(prod); } return sum; diff --git a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java index 847c6fb215..d512c0264a 100644 --- a/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java +++ b/src/main/java/net/consensys/linea/zktracer/module/alu/mul/MulData.java @@ -1,3 +1,17 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ package net.consensys.linea.zktracer.module.alu.mul; import static net.consensys.linea.zktracer.module.Util.boolToByte; diff --git a/src/test/java/net/consensys/linea/zktracer/AbstractModuleTracerCorsetTest.java b/src/test/java/net/consensys/linea/zktracer/AbstractModuleTracerCorsetTest.java index 3a1ae8bae0..913fd42d91 100644 --- a/src/test/java/net/consensys/linea/zktracer/AbstractModuleTracerCorsetTest.java +++ b/src/test/java/net/consensys/linea/zktracer/AbstractModuleTracerCorsetTest.java @@ -20,6 +20,7 @@ import java.util.stream.Stream; import org.apache.tuweni.bytes.Bytes32; +import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; import org.junit.jupiter.params.ParameterizedTest; @@ -27,6 +28,7 @@ import org.junit.jupiter.params.provider.MethodSource; @TestInstance(Lifecycle.PER_CLASS) +@Tag("CorsetTest") public abstract class AbstractModuleTracerCorsetTest extends AbstractBaseModuleTracerTest { static final Random rand = new Random(); private static final int TEST_REPETITIONS = 8; diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/add/AddTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AddTracerTest.java index 3744d0623f..066b66eaf4 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/add/AddTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AddTracerTest.java @@ -22,7 +22,6 @@ import net.consensys.linea.zktracer.AbstractModuleTracerCorsetTest; import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.module.ModuleTracer; -import net.consensys.linea.zktracer.module.alu.add.AddTracer; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/add/AdderTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AdderTest.java index b0cae0758a..590c74bb96 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/add/AdderTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/add/AdderTest.java @@ -18,7 +18,6 @@ import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.bytestheta.BaseBytes; -import net.consensys.linea.zktracer.module.alu.add.Adder; import org.apache.tuweni.bytes.Bytes32; import org.junit.jupiter.api.Test; diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java index 0f174d6dba..8170e01673 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java @@ -22,7 +22,6 @@ import net.consensys.linea.zktracer.AbstractModuleTracerCorsetTest; import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.module.ModuleTracer; -import net.consensys.linea.zktracer.module.alu.ext.ExtTracer; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; import org.junit.jupiter.api.Assertions; diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mod/ModTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mod/ModTracerTest.java index 5da8d66b99..b046a13bf5 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/mod/ModTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mod/ModTracerTest.java @@ -23,7 +23,6 @@ import net.consensys.linea.zktracer.AbstractModuleTracerCorsetTest; import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.module.ModuleTracer; -import net.consensys.linea.zktracer.module.alu.mod.ModTracer; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java index 702c646c62..2671110630 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/mul/MulUtilsTest.java @@ -1,3 +1,17 @@ +/* + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ package net.consensys.linea.zktracer.module.alu.mul; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; diff --git a/src/test/java/net/consensys/linea/zktracer/module/shf/ShfTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/shf/ShfTracerTest.java index 1d3609fa83..5b24a645b0 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/shf/ShfTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/shf/ShfTracerTest.java @@ -32,6 +32,7 @@ import org.apache.tuweni.bytes.Bytes32; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Named; +import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; @@ -43,6 +44,7 @@ import org.slf4j.LoggerFactory; @ExtendWith(MockitoExtension.class) +@Tag("CorsetTest") class ShfTracerTest { private static final Logger LOG = LoggerFactory.getLogger(ShfTracerTest.class); diff --git a/src/test/java/net/consensys/linea/zktracer/module/wcp/WcpTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/wcp/WcpTracerTest.java index 337c95daae..3a3e78c1aa 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/wcp/WcpTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/wcp/WcpTracerTest.java @@ -14,87 +14,32 @@ */ package net.consensys.linea.zktracer.module.wcp; -import static net.consensys.linea.zktracer.OpCode.SGT; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import static org.mockito.Mockito.when; +import static net.consensys.linea.zktracer.OpCode.GT; -import org.hyperledger.besu.evm.frame.MessageFrame; -import org.hyperledger.besu.evm.operation.Operation; - -import java.util.ArrayList; import java.util.List; -import java.util.Random; import java.util.stream.Stream; -import net.consensys.linea.zktracer.OpCode; -import net.consensys.linea.zktracer.ZkTraceBuilder; -import net.consensys.linea.zktracer.ZkTracer; -import net.consensys.linea.zktracer.corset.CorsetValidator; -import net.consensys.linea.zktracer.module.wcp.WcpTracer; +import net.consensys.linea.zktracer.AbstractModuleTracerCorsetTest; +import net.consensys.linea.zktracer.module.ModuleTracer; import org.apache.tuweni.bytes.Bytes32; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; @ExtendWith(MockitoExtension.class) -class WcpTracerTest { - private ZkTracer zkTracer; - private ZkTraceBuilder zkTraceBuilder; - @Mock MessageFrame mockFrame; - @Mock Operation mockOperation; - - private static final Random rand = new Random(); - private static final int TEST_REPETITIONS = 4; - - @BeforeEach - void setUp() { - zkTraceBuilder = new ZkTraceBuilder(); - zkTracer = new ZkTracer(zkTraceBuilder, List.of(new WcpTracer())); - when(mockFrame.getCurrentOperation()).thenReturn(mockOperation); - } +class WcpTracerTest extends AbstractModuleTracerCorsetTest { - @ParameterizedTest() - @MethodSource("provideRandomArguments") - void testRandomWcp(OpCode opCode, final Bytes32 arg1, Bytes32 arg2) { - when(mockOperation.getOpcode()).thenReturn((int) opCode.value); - when(mockFrame.getStackItem(0)).thenReturn(arg1); - when(mockFrame.getStackItem(1)).thenReturn(arg2); - zkTracer.tracePreExecution(mockFrame); - assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + @Override + protected ModuleTracer getModuleTracer() { + return new WcpTracer(); } - @Test - public void testNonRandomWcp() { + @Override + protected Stream provideNonRandomArguments() { Bytes32 arg1 = Bytes32.fromHexString("0xdcd5cf52e4daec5389587d0d0e996e6ce2d0546b63d3ea0a0dc48ad984d180a9"); Bytes32 arg2 = Bytes32.fromHexString("0x0479484af4a59464a48818b3980174687661bafb13d06f49537995fa6c02159e"); - traceOperation(SGT, arg1, arg2); - } - - public static Stream provideRandomArguments() { - final List arguments = new ArrayList<>(); - for (OpCode opCode : new WcpTracer().supportedOpCodes()) { - for (int i = 0; i <= TEST_REPETITIONS; i++) { - Bytes32[] payload = new Bytes32[2]; - payload[0] = Bytes32.random(rand); - payload[1] = Bytes32.random(rand); - arguments.add(Arguments.of(opCode, payload[0], payload[1])); - } - } - return arguments.stream(); - } - - private void traceOperation(OpCode opcode, Bytes32 arg1, Bytes32 arg2) { - when(mockOperation.getOpcode()).thenReturn((int) opcode.value); - when(mockFrame.getStackItem(0)).thenReturn(arg1); - when(mockFrame.getStackItem(1)).thenReturn(arg2); - zkTracer.tracePreExecution(mockFrame); - assertThat(CorsetValidator.isValid(zkTraceBuilder.build().toJson())).isTrue(); + return Stream.of(Arguments.of(GT, List.of(arg1, arg2))); } }