Skip to content

Commit

Permalink
perf: defer computing EXT components at trace-time (#414)
Browse files Browse the repository at this point in the history
Fixes #412 

Signed-off-by: [email protected]
  • Loading branch information
delehef authored Nov 21, 2023
1 parent f9efcfb commit 57b2da7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ public void popTransaction() {
@Override
public void tracePreOpcode(final MessageFrame frame) {
final OpCodeData opCode = OpCodes.of(frame.getCurrentOperation().getOpcode());
final Bytes32 arg1 = Bytes32.leftPad(frame.getStackItem(0));
final Bytes32 arg2 = Bytes32.leftPad(frame.getStackItem(1));
final Bytes32 arg3 = Bytes32.leftPad(frame.getStackItem(2));

this.operations.add(new ExtOperation(opCode, arg1, arg2, arg3));
this.operations.add(
new ExtOperation(
opCode.mnemonic(),
Bytes32.leftPad(frame.getStackItem(0)),
Bytes32.leftPad(frame.getStackItem(1)),
Bytes32.leftPad(frame.getStackItem(2))));
}

public void traceExtOperation(ExtOperation op, Trace trace) {
Expand Down Expand Up @@ -213,6 +214,7 @@ public List<ColumnHeader> columnsHeaders() {
public void commit(List<MappedByteBuffer> buffers) {
final Trace trace = new Trace(buffers);
for (ExtOperation operation : this.operations) {
operation.setup();
this.traceExtOperation(operation, trace);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import net.consensys.linea.zktracer.bytestheta.BytesArray;
import net.consensys.linea.zktracer.module.ext.calculator.AbstractExtCalculator;
import net.consensys.linea.zktracer.opcode.OpCode;
import net.consensys.linea.zktracer.opcode.OpCodeData;
import org.apache.tuweni.bytes.Bytes32;
import org.apache.tuweni.units.bigints.UInt256;

Expand All @@ -37,14 +36,15 @@ public class ExtOperation {
@Getter private final BaseBytes arg1;
@Getter private final BaseBytes arg2;
@Getter private final BaseBytes arg3;
@Getter private final BaseTheta result;
@Getter private final BaseTheta aBytes;
@Getter private final BaseTheta bBytes;
@Getter private final BaseTheta cBytes;

@Getter private BaseTheta result;
@Getter private BaseTheta aBytes;
@Getter private BaseTheta bBytes;
@Getter private BaseTheta cBytes;
@Getter private BaseTheta deltaBytes;
@Getter private final BytesArray hBytes;
@Getter private final BaseTheta rBytes;
@Getter private final BytesArray iBytes;
@Getter private BytesArray hBytes;
@Getter private BaseTheta rBytes;
@Getter private BytesArray iBytes;
@Getter private BytesArray jBytes;
@Getter private BytesArray qBytes;
@Getter private boolean[] cmp = new boolean[8];
Expand All @@ -62,36 +62,39 @@ public int hashCode() {
return Objects.hash(this.opCode, this.arg1, this.arg2, this.arg3);
}

public ExtOperation(OpCodeData opCodeData, Bytes32 arg1, Bytes32 arg2, Bytes32 arg3) {
this(opCodeData.mnemonic(), arg1, arg2, arg3);
}

public ExtOperation(OpCode opCode, Bytes32 arg1, Bytes32 arg2, Bytes32 arg3) {
this.opCode = opCode;
this.arg1 = BaseBytes.fromBytes32(arg1);
this.arg2 = BaseBytes.fromBytes32(arg2);
this.arg3 = BaseBytes.fromBytes32(arg3);
this.aBytes = BaseTheta.fromBytes32(arg1);
this.bBytes = BaseTheta.fromBytes32(arg2);
this.cBytes = BaseTheta.fromBytes32(arg3);
this.arg1 = BaseBytes.fromBytes32(arg1.copy());
this.arg2 = BaseBytes.fromBytes32(arg2.copy());
this.arg3 = BaseBytes.fromBytes32(arg3.copy());
this.oli = isOneLineInstruction();
}

public void setup() {
this.aBytes = BaseTheta.fromBytes32(this.arg1.getBytes32());
this.bBytes = BaseTheta.fromBytes32(this.arg2.getBytes32());
this.cBytes = BaseTheta.fromBytes32(this.arg3.getBytes32());
this.iBytes = new BytesArray(7);
this.jBytes = new BytesArray(8);
this.qBytes = new BytesArray(8);
this.deltaBytes = BaseTheta.fromBytes32(Bytes32.ZERO);
this.hBytes = new BytesArray(6);

AbstractExtCalculator computer = AbstractExtCalculator.create(opCode);
UInt256 result = computer.computeResult(arg1, arg2, arg3);
UInt256 result =
computer.computeResult(
this.arg1.getBytes32(), this.arg2.getBytes32(), this.arg3.getBytes32());

this.result = BaseTheta.fromBytes32(result);
this.rBytes = BaseTheta.fromBytes32(result);

this.oli = isOneLineInstruction();
if (!this.oli) {
cmp = computer.computeComparisonFlags(cBytes, rBytes);
deltaBytes = computer.computeDeltas(cBytes, rBytes);
jBytes = computer.computeJs(arg1, arg2);
qBytes = computer.computeQs(arg1, arg2, arg3);
jBytes = computer.computeJs(this.arg1.getBytes32(), this.arg2.getBytes32());
qBytes =
computer.computeQs(
this.arg1.getBytes32(), this.arg2.getBytes32(), this.arg3.getBytes32());
overflowH = computer.computeHs(aBytes, bBytes, hBytes);
overflowI = computer.computeIs(qBytes, cBytes, iBytes);
overflowJ = computer.computeOverflowJ(qBytes, cBytes, rBytes, iBytes, getSigma(), getTau());
Expand All @@ -110,9 +113,7 @@ public boolean getBit2() {
}

public boolean getBit3() {
UInt256 uInt256 = UInt256.fromBytes(this.arg3.getBytes32());

return UInt256.ONE.compareTo(uInt256) >= 0;
return UInt256.ONE.compareTo(UInt256.fromBytes(this.arg3.getBytes32())) >= 0;
}

/** Returns true if any of the bit1, bit2, or bit3 flags are set. */
Expand Down

0 comments on commit 57b2da7

Please sign in to comment.