From 39077d6ca9cf2bcfac8939fcd87ba26186a2acce Mon Sep 17 00:00:00 2001 From: ahamlat Date: Fri, 11 Oct 2024 10:49:17 +0200 Subject: [PATCH] Improve StackItem memory footprint (#1390) * Improve StackOperation memory footprint by changing height from int to short Signed-off-by: Ameziane H --- .../fragment/common/CommonFragmentValues.java | 4 +- .../module/hub/section/TraceSection.java | 6 +- .../linea/zktracer/runtime/stack/Stack.java | 131 +++++++++++------- .../zktracer/runtime/stack/StackItem.java | 10 +- 4 files changed, 93 insertions(+), 58 deletions(-) diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/hub/fragment/common/CommonFragmentValues.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/hub/fragment/common/CommonFragmentValues.java index 5b171afcce..3310cc17c3 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/hub/fragment/common/CommonFragmentValues.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/hub/fragment/common/CommonFragmentValues.java @@ -81,8 +81,8 @@ public CommonFragmentValues(Hub hub) { // this.contextNumberNew = hub.contextNumberNew(callFrame); this.pc = hubProcessingPhase == TX_EXEC ? hub.currentFrame().pc() : 0; this.pcNew = computePcNew(hub, pc, noStackException, hub.state.getProcessingPhase() == TX_EXEC); - this.height = (short) callFrame.stack().getHeight(); - this.heightNew = (short) callFrame.stack().getHeightNew(); + this.height = callFrame.stack().getHeight(); + this.heightNew = callFrame.stack().getHeightNew(); // TODO: partial solution, will not work in general this.gasExpected = hub.expectedGas(); diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/hub/section/TraceSection.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/hub/section/TraceSection.java index d29505f630..ee180b48c6 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/hub/section/TraceSection.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/hub/section/TraceSection.java @@ -37,6 +37,7 @@ import net.consensys.linea.zktracer.module.hub.fragment.common.CommonFragment; import net.consensys.linea.zktracer.module.hub.fragment.common.CommonFragmentValues; import net.consensys.linea.zktracer.runtime.callstack.CallFrame; +import net.consensys.linea.zktracer.runtime.stack.Stack; import net.consensys.linea.zktracer.runtime.stack.StackLine; import org.apache.tuweni.bytes.Bytes; import org.hyperledger.besu.evm.internal.Words; @@ -158,12 +159,13 @@ private int computeContextNumberNew() { private List makeStackFragments(final Hub hub, CallFrame f) { final List r = new ArrayList<>(2); + Stack snapshot = f.stack().snapshot(); if (f.pending().lines().isEmpty()) { for (int i = 0; i < (f.opCodeData().stackSettings().twoLineInstruction() ? 2 : 1); i++) { r.add( StackFragment.prepare( hub, - f.stack().snapshot(), + snapshot, new StackLine().asStackItems(), hub.pch().exceptions(), hub.pch().abortingConditions().snapshot(), @@ -177,7 +179,7 @@ private List makeStackFragments(final Hub hub, CallFrame f) { r.add( StackFragment.prepare( hub, - f.stack().snapshot(), + snapshot, line.asStackItems(), hub.pch().exceptions(), hub.pch().abortingConditions().snapshot(), diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/runtime/stack/Stack.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/runtime/stack/Stack.java index 5bfc5d38aa..c587df290d 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/runtime/stack/Stack.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/runtime/stack/Stack.java @@ -31,8 +31,8 @@ public class Stack { public static final byte PUSH = 1; public static final byte POP = 2; - @Getter int height; - @Getter int heightNew; + @Getter short height; + @Getter short heightNew; @Getter OpCodeData currentOpcodeData; Status status; int stamp; @@ -78,7 +78,8 @@ private void twoZero(MessageFrame frame, StackContext pending) { private void zeroOne(MessageFrame ignoredFrame, StackContext pending) { pending.addArmingLine( - new IndexedStackOperation(4, StackItem.push(height + 1, stackStampWithOffset(0)))); + new IndexedStackOperation( + 4, StackItem.push((short) (height + 1), stackStampWithOffset(0)))); } private void oneOne(MessageFrame frame, StackContext pending) { @@ -95,8 +96,10 @@ private void twoOne(MessageFrame frame, StackContext pending) { pending.addArmingLine( new IndexedStackOperation(1, StackItem.pop(height, val1, stackStampWithOffset(0))), - new IndexedStackOperation(2, StackItem.pop(height - 1, val2, stackStampWithOffset(1))), - new IndexedStackOperation(4, StackItem.push(height - 1, stackStampWithOffset(2)))); + new IndexedStackOperation( + 2, StackItem.pop((short) (height - 1), val2, stackStampWithOffset(1))), + new IndexedStackOperation( + 4, StackItem.push((short) (height - 1), stackStampWithOffset(2)))); } private void threeOne(MessageFrame frame, StackContext pending) { @@ -106,9 +109,12 @@ private void threeOne(MessageFrame frame, StackContext pending) { pending.addArmingLine( new IndexedStackOperation(1, StackItem.pop(height, val1, stackStampWithOffset(0))), - new IndexedStackOperation(2, StackItem.pop(height - 1, val2, stackStampWithOffset(1))), - new IndexedStackOperation(3, StackItem.pop(height - 2, val3, stackStampWithOffset(2))), - new IndexedStackOperation(4, StackItem.push(height - 2, stackStampWithOffset(3)))); + new IndexedStackOperation( + 2, StackItem.pop((short) (height - 1), val2, stackStampWithOffset(1))), + new IndexedStackOperation( + 3, StackItem.pop((short) (height - 2), val3, stackStampWithOffset(2))), + new IndexedStackOperation( + 4, StackItem.push((short) (height - 2), stackStampWithOffset(3)))); } private void loadStore(MessageFrame frame, StackContext pending) { @@ -118,7 +124,8 @@ private void loadStore(MessageFrame frame, StackContext pending) { pending.addLine( new IndexedStackOperation(1, StackItem.pop(height, val1, stackStampWithOffset(0))), - new IndexedStackOperation(4, StackItem.pop(height - 1, val2, stackStampWithOffset(1)))); + new IndexedStackOperation( + 4, StackItem.pop((short) (height - 1), val2, stackStampWithOffset(1)))); } else { Bytes val = getStack(frame, 0); @@ -133,11 +140,12 @@ private void dup(MessageFrame frame, StackContext pending) { Bytes val = getStack(frame, depth); pending.addLine( - new IndexedStackOperation(1, StackItem.pop(height - depth, val, stackStampWithOffset(0))), new IndexedStackOperation( - 2, StackItem.pushImmediate(height - depth, val, stackStampWithOffset(1))), + 1, StackItem.pop((short) (height - depth), val, stackStampWithOffset(0))), new IndexedStackOperation( - 4, StackItem.pushImmediate(height + 1, val, stackStampWithOffset(2)))); + 2, StackItem.pushImmediate((short) (height - depth), val, stackStampWithOffset(1))), + new IndexedStackOperation( + 4, StackItem.pushImmediate((short) (height + 1), val, stackStampWithOffset(2)))); } private void swap(MessageFrame frame, StackContext pending) { @@ -146,10 +154,11 @@ private void swap(MessageFrame frame, StackContext pending) { Bytes val2 = getStack(frame, depth); pending.addLine( - new IndexedStackOperation(1, StackItem.pop(height - depth, val1, stackStampWithOffset(0))), + new IndexedStackOperation( + 1, StackItem.pop((short) (height - depth), val1, stackStampWithOffset(0))), new IndexedStackOperation(2, StackItem.pop(height, val2, stackStampWithOffset(1))), new IndexedStackOperation( - 3, StackItem.pushImmediate(height - depth, val2, stackStampWithOffset(2))), + 3, StackItem.pushImmediate((short) (height - depth), val2, stackStampWithOffset(2))), new IndexedStackOperation( 4, StackItem.pushImmediate(height, val1, stackStampWithOffset(3)))); } @@ -161,7 +170,8 @@ private void log(MessageFrame frame, StackContext pending) { // Stack line 1 pending.addLine( new IndexedStackOperation(1, StackItem.pop(height, offset, stackStampWithOffset(0))), - new IndexedStackOperation(2, StackItem.pop(height - 1, size, stackStampWithOffset(1)))); + new IndexedStackOperation( + 2, StackItem.pop((short) (height - 1), size, stackStampWithOffset(1)))); // Stack line 2 IndexedStackOperation[] line2 = new IndexedStackOperation[] {}; @@ -173,7 +183,7 @@ private void log(MessageFrame frame, StackContext pending) { line2 = new IndexedStackOperation[] { new IndexedStackOperation( - 1, StackItem.pop(height - 2, topic1, stackStampWithOffset(0))), + 1, StackItem.pop((short) (height - 2), topic1, stackStampWithOffset(0))), }; } case LOG2 -> { @@ -183,9 +193,9 @@ private void log(MessageFrame frame, StackContext pending) { line2 = new IndexedStackOperation[] { new IndexedStackOperation( - 1, StackItem.pop(height - 2, topic1, stackStampWithOffset(2))), + 1, StackItem.pop((short) (height - 2), topic1, stackStampWithOffset(2))), new IndexedStackOperation( - 2, StackItem.pop(height - 3, topic2, stackStampWithOffset(3))), + 2, StackItem.pop((short) (height - 3), topic2, stackStampWithOffset(3))), }; } case LOG3 -> { @@ -196,11 +206,11 @@ private void log(MessageFrame frame, StackContext pending) { line2 = new IndexedStackOperation[] { new IndexedStackOperation( - 1, StackItem.pop(height - 2, topic1, stackStampWithOffset(2))), + 1, StackItem.pop((short) (height - 2), topic1, stackStampWithOffset(2))), new IndexedStackOperation( - 2, StackItem.pop(height - 3, topic2, stackStampWithOffset(3))), + 2, StackItem.pop((short) (height - 3), topic2, stackStampWithOffset(3))), new IndexedStackOperation( - 3, StackItem.pop(height - 4, topic3, stackStampWithOffset(4))), + 3, StackItem.pop((short) (height - 4), topic3, stackStampWithOffset(4))), }; } case LOG4 -> { @@ -212,13 +222,13 @@ private void log(MessageFrame frame, StackContext pending) { line2 = new IndexedStackOperation[] { new IndexedStackOperation( - 1, StackItem.pop(height - 2, topic1, stackStampWithOffset(2))), + 1, StackItem.pop((short) (height - 2), topic1, stackStampWithOffset(2))), new IndexedStackOperation( - 2, StackItem.pop(height - 3, topic2, stackStampWithOffset(3))), + 2, StackItem.pop((short) (height - 3), topic2, stackStampWithOffset(3))), new IndexedStackOperation( - 3, StackItem.pop(height - 4, topic3, stackStampWithOffset(4))), + 3, StackItem.pop((short) (height - 4), topic3, stackStampWithOffset(4))), new IndexedStackOperation( - 4, StackItem.pop(height - 5, topic4, stackStampWithOffset(5))), + 4, StackItem.pop((short) (height - 5), topic4, stackStampWithOffset(5))), }; } default -> throw new RuntimeException("not a LOGx"); @@ -234,9 +244,12 @@ private void copy(MessageFrame frame, StackContext pending) { Bytes val3 = getStack(frame, 3); pending.addLine( - new IndexedStackOperation(1, StackItem.pop(height - 1, val1, stackStampWithOffset(1))), - new IndexedStackOperation(2, StackItem.pop(height - 3, val3, stackStampWithOffset(2))), - new IndexedStackOperation(3, StackItem.pop(height - 2, val2, stackStampWithOffset(3))), + new IndexedStackOperation( + 1, StackItem.pop((short) (height - 1), val1, stackStampWithOffset(1))), + new IndexedStackOperation( + 2, StackItem.pop((short) (height - 3), val3, stackStampWithOffset(2))), + new IndexedStackOperation( + 3, StackItem.pop((short) (height - 2), val2, stackStampWithOffset(3))), new IndexedStackOperation(4, StackItem.pop(height, val0, stamp))); } else { Bytes val1 = getStack(frame, 0); @@ -245,8 +258,10 @@ private void copy(MessageFrame frame, StackContext pending) { pending.addLine( new IndexedStackOperation(1, StackItem.pop(height, val1, stackStampWithOffset(1))), - new IndexedStackOperation(2, StackItem.pop(height - 2, val2, stackStampWithOffset(2))), - new IndexedStackOperation(3, StackItem.pop(height - 1, val3, stackStampWithOffset(3)))); + new IndexedStackOperation( + 2, StackItem.pop((short) (height - 2), val2, stackStampWithOffset(2))), + new IndexedStackOperation( + 3, StackItem.pop((short) (height - 1), val3, stackStampWithOffset(3)))); } } @@ -264,27 +279,40 @@ private void call(MessageFrame frame, StackContext pending) { Bytes val7 = getStack(frame, 6); pending.addLine( - new IndexedStackOperation(1, StackItem.pop(height - 3, val4, stackStampWithOffset(3))), - new IndexedStackOperation(2, StackItem.pop(height - 4, val5, stackStampWithOffset(4))), - new IndexedStackOperation(3, StackItem.pop(height - 5, val6, stackStampWithOffset(5))), - new IndexedStackOperation(4, StackItem.pop(height - 6, val7, stackStampWithOffset(6)))); + new IndexedStackOperation( + 1, StackItem.pop((short) (height - 3), val4, stackStampWithOffset(3))), + new IndexedStackOperation( + 2, StackItem.pop((short) (height - 4), val5, stackStampWithOffset(4))), + new IndexedStackOperation( + 3, StackItem.pop((short) (height - 5), val6, stackStampWithOffset(5))), + new IndexedStackOperation( + 4, StackItem.pop((short) (height - 6), val7, stackStampWithOffset(6)))); pending.addArmingLine( new IndexedStackOperation(1, StackItem.pop(height, val1, stackStampWithOffset(0))), - new IndexedStackOperation(2, StackItem.pop(height - 1, val2, stackStampWithOffset(1))), - new IndexedStackOperation(3, StackItem.pop(height - 2, val3, stackStampWithOffset(2))), - new IndexedStackOperation(4, StackItem.push(height - 6, stackStampWithOffset(7)))); + new IndexedStackOperation( + 2, StackItem.pop((short) (height - 1), val2, stackStampWithOffset(1))), + new IndexedStackOperation( + 3, StackItem.pop((short) (height - 2), val3, stackStampWithOffset(2))), + new IndexedStackOperation( + 4, StackItem.push((short) (height - 6), stackStampWithOffset(7)))); } else { pending.addLine( - new IndexedStackOperation(1, StackItem.pop(height - 2, val3, stackStampWithOffset(3))), - new IndexedStackOperation(2, StackItem.pop(height - 3, val4, stackStampWithOffset(4))), - new IndexedStackOperation(3, StackItem.pop(height - 4, val5, stackStampWithOffset(5))), - new IndexedStackOperation(4, StackItem.pop(height - 5, val6, stackStampWithOffset(6)))); + new IndexedStackOperation( + 1, StackItem.pop((short) (height - 2), val3, stackStampWithOffset(3))), + new IndexedStackOperation( + 2, StackItem.pop((short) (height - 3), val4, stackStampWithOffset(4))), + new IndexedStackOperation( + 3, StackItem.pop((short) (height - 4), val5, stackStampWithOffset(5))), + new IndexedStackOperation( + 4, StackItem.pop((short) (height - 5), val6, stackStampWithOffset(6)))); pending.addArmingLine( new IndexedStackOperation(1, StackItem.pop(height, val1, stackStampWithOffset(0))), - new IndexedStackOperation(2, StackItem.pop(height - 1, val2, stackStampWithOffset(1))), - new IndexedStackOperation(4, StackItem.push(height - 5, stackStampWithOffset(7)))); + new IndexedStackOperation( + 2, StackItem.pop((short) (height - 1), val2, stackStampWithOffset(1))), + new IndexedStackOperation( + 4, StackItem.push((short) (height - 5), stackStampWithOffset(7)))); } } @@ -293,17 +321,21 @@ private void create(MessageFrame frame, StackContext pending) { final Bytes val2 = getStack(frame, 2); pending.addLine( - new IndexedStackOperation(1, StackItem.pop(height - 1, val1, stackStampWithOffset(1))), - new IndexedStackOperation(2, StackItem.pop(height - 2, val2, stackStampWithOffset(2)))); + new IndexedStackOperation( + 1, StackItem.pop((short) (height - 1), val1, stackStampWithOffset(1))), + new IndexedStackOperation( + 2, StackItem.pop((short) (height - 2), val2, stackStampWithOffset(2)))); // case CREATE2 if (currentOpcodeData.stackSettings().flag2()) { final Bytes val3 = getStack(frame, 3); final Bytes val4 = getStack(frame, 0); pending.addArmingLine( - new IndexedStackOperation(2, StackItem.pop(height - 3, val3, stackStampWithOffset(3))), + new IndexedStackOperation( + 2, StackItem.pop((short) (height - 3), val3, stackStampWithOffset(3))), new IndexedStackOperation(3, StackItem.pop(height, val4, stackStampWithOffset(0))), - new IndexedStackOperation(4, StackItem.push(height - 3, stackStampWithOffset(4)))); + new IndexedStackOperation( + 4, StackItem.push((short) (height - 3), stackStampWithOffset(4)))); } else // case CREATE { @@ -311,7 +343,8 @@ private void create(MessageFrame frame, StackContext pending) { pending.addArmingLine( new IndexedStackOperation(3, StackItem.pop(height, val4, stackStampWithOffset(0))), - new IndexedStackOperation(4, StackItem.push(height - 2, stackStampWithOffset(4)))); + new IndexedStackOperation( + 4, StackItem.push((short) (height - 2), stackStampWithOffset(4)))); } } @@ -346,7 +379,7 @@ public void processInstruction(final Hub hub, MessageFrame frame, int stackStamp final int delta = currentOpcodeData.stackSettings().delta(); Preconditions.checkArgument(heightNew == frame.stackSize()); - height = frame.stackSize(); + height = (short) frame.stackSize(); heightNew += currentOpcodeData.stackSettings().nbAdded(); heightNew -= currentOpcodeData.stackSettings().nbRemoved(); diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/runtime/stack/StackItem.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/runtime/stack/StackItem.java index b473f8cc31..d65917fd08 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/runtime/stack/StackItem.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/runtime/stack/StackItem.java @@ -36,7 +36,7 @@ public final class StackItem { * The relative height of the element with regard to the stack height just before executing the * linked EVM instruction. */ - @Getter private final int height; + @Getter private final short height; /** The value having been popped from/pushed on the stack. */ @Getter @Setter private Bytes value; @@ -65,23 +65,23 @@ private StackItem() { this.stackStamp = 0; } - StackItem(int height, Bytes value, byte action, int stackStamp) { + StackItem(short height, Bytes value, byte action, int stackStamp) { this.height = height; this.value = value; this.action = action; this.stackStamp = stackStamp; } - public static StackItem pop(int height, Bytes value, int stackStamp) { + public static StackItem pop(short height, Bytes value, int stackStamp) { return new StackItem(height, value, Stack.POP, stackStamp); } - public static StackItem push(int height, int stackStamp) { + public static StackItem push(short height, int stackStamp) { return new StackItem( height, MARKER /* marker value, erased on unlatching */, Stack.PUSH, stackStamp); } - public static StackItem pushImmediate(int height, Bytes val, int stackStamp) { + public static StackItem pushImmediate(short height, Bytes val, int stackStamp) { return new StackItem(height, val.copy(), Stack.PUSH, stackStamp); } }