diff --git a/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java b/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java index 4f0bdb7f14..9665584846 100644 --- a/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java +++ b/src/test/java/net/consensys/linea/zktracer/module/alu/ext/ExtTracerTest.java @@ -14,7 +14,6 @@ */ package net.consensys.linea.zktracer.module.alu.ext; -import java.math.BigInteger; import java.util.ArrayList; import java.util.List; import java.util.Random; @@ -23,17 +22,18 @@ import net.consensys.linea.zktracer.OpCode; import net.consensys.linea.zktracer.module.AbstractModuleTracerTest; import net.consensys.linea.zktracer.module.ModuleTracer; -import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.apache.tuweni.units.bigints.UInt256; import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.junit.jupiter.MockitoExtension; @ExtendWith(MockitoExtension.class) class ExtTracerTest extends AbstractModuleTracerTest { + static final Random rand = new Random(); @Override @@ -64,74 +64,71 @@ public Stream provideNonRandomArguments() { return arguments.stream(); } - @Test - public void argumentZeroValueTestMulModTest() { - Bytes32 arg1 = fromBigInteger(BigInteger.valueOf(0)); - Bytes32 arg2 = fromBigInteger(BigInteger.valueOf(7)); - Bytes32 arg3 = fromBigInteger(BigInteger.valueOf(13)); - runTest(OpCode.MULMOD, List.of(arg1, arg2, arg3)); + @ParameterizedTest() + @MethodSource("provideZeroValueTest") + public void argumentZeroValueTestMulModTest(final OpCode opCode, final List arguments) { + runTest(opCode, arguments); } - @Test - public void argumentZeroValueTestAddModTest() { - Bytes32 arg1 = fromBigInteger(BigInteger.valueOf(0)); - Bytes32 arg2 = fromBigInteger(BigInteger.valueOf(7)); - Bytes32 arg3 = fromBigInteger(BigInteger.valueOf(13)); - runTest(OpCode.ADDMOD, List.of(arg1, arg2, arg3)); + @ParameterizedTest() + @MethodSource("provideModulusZeroValueArguments") + public void modulusZeroValueTestMulModTest(final OpCode opCode, final List arguments) { + Assertions.assertThrows(ArithmeticException.class, () -> runTest(opCode, arguments)); } - @Test - public void modulusZeroValueTestMulModTest() { - Bytes32 arg1 = fromBigInteger(BigInteger.valueOf(1)); - Bytes32 arg2 = fromBigInteger(BigInteger.valueOf(1)); - Bytes32 arg3 = fromBigInteger(BigInteger.valueOf(0)); - - Assertions.assertThrows( - ArithmeticException.class, () -> runTest(OpCode.MULMOD, List.of(arg1, arg2, arg3))); + @ParameterizedTest() + @MethodSource("provideTinyValueArguments") + public void tinyValueTest(final OpCode opCode, final List arguments) { + runTest(opCode, arguments); } - @Test - public void modulusZeroValueTestAddModTest() { - Bytes32 arg1 = fromBigInteger(BigInteger.valueOf(1)); - Bytes32 arg2 = fromBigInteger(BigInteger.valueOf(1)); - Bytes32 arg3 = fromBigInteger(BigInteger.valueOf(0)); - Assertions.assertThrows( - ArithmeticException.class, () -> runTest(OpCode.MULMOD, List.of(arg1, arg2, arg3))); + @ParameterizedTest() + @MethodSource("provideMaxValueArguments") + public void maxValueTest(final OpCode opCode, final List arguments) { + runTest(opCode, arguments); } - @Test - public void tinyValueTest() { - Bytes32 arg1 = fromBigInteger(BigInteger.valueOf(6)); - Bytes32 arg2 = fromBigInteger(BigInteger.valueOf(7)); - Bytes32 arg3 = fromBigInteger(BigInteger.valueOf(13)); - runTest(OpCode.MULMOD, List.of(arg1, arg2, arg3)); + @Override + protected ModuleTracer getModuleTracer() { + return new ExtTracer(); } - @Test - public void maxExactValueTest() { - Bytes32 arg1 = UInt256.MAX_VALUE; - Bytes32 arg2 = UInt256.MAX_VALUE; - Bytes32 arg3 = UInt256.MAX_VALUE; - runTest(OpCode.MULMOD, List.of(arg1, arg2, arg3)); + public Stream provideZeroValueTest() { + List arguments = new ArrayList<>(); + for (OpCode opCode : getModuleTracer().supportedOpCodes()) { + arguments.add( + Arguments.of( + opCode, List.of(UInt256.valueOf(0), UInt256.valueOf(12), UInt256.valueOf(6)))); + } + return arguments.stream(); } - @Test - public void largeValueTest() { - Bytes32 arg1 = - UInt256.fromHexString("0xcb694eaa08d8cb30a26a74edc8ced2cc0d7f453c6df96307bf3d9336784aba26"); - Bytes32 arg2 = - UInt256.fromHexString("0xb1f3d8555ff1d8e1d1db41eb8640cdc0b5dc1ea19a87bd0cb046b634ab707409"); - Bytes32 arg3 = - UInt256.fromHexString("0x07d761cc7e0bf9770db9d952e5b108c96e6c3f0526218d2bfbef3071b0d776b8"); - runTest(OpCode.ADDMOD, List.of(arg1, arg2, arg3)); + public Stream provideModulusZeroValueArguments() { + List arguments = new ArrayList<>(); + for (OpCode opCode : getModuleTracer().supportedOpCodes()) { + arguments.add( + Arguments.of( + opCode, List.of(UInt256.valueOf(1), UInt256.valueOf(1), UInt256.valueOf(0)))); + } + return arguments.stream(); } - @Override - protected ModuleTracer getModuleTracer() { - return new ExtTracer(); + public Stream provideTinyValueArguments() { + List arguments = new ArrayList<>(); + for (OpCode opCode : getModuleTracer().supportedOpCodes()) { + arguments.add( + Arguments.of( + opCode, List.of(UInt256.valueOf(6), UInt256.valueOf(7), UInt256.valueOf(13)))); + } + return arguments.stream(); } - private Bytes32 fromBigInteger(BigInteger bigInteger) { - return Bytes32.leftPad(Bytes.wrap(bigInteger.toByteArray())); + public Stream provideMaxValueArguments() { + List arguments = new ArrayList<>(); + for (OpCode opCode : getModuleTracer().supportedOpCodes()) { + arguments.add( + Arguments.of(opCode, List.of(UInt256.MAX_VALUE, UInt256.MAX_VALUE, UInt256.MAX_VALUE))); + } + return arguments.stream(); } }