diff --git a/prover/backend/blobsubmission/blobcompression_test.go b/prover/backend/blobsubmission/blobcompression_test.go index d55d05c1b..e8ae12702 100644 --- a/prover/backend/blobsubmission/blobcompression_test.go +++ b/prover/backend/blobsubmission/blobcompression_test.go @@ -5,12 +5,11 @@ import ( "crypto/sha256" "encoding/json" "fmt" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" "os" "strings" "testing" - blob "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1" - fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/zkevm-monorepo/prover/utils" gokzg4844 "github.com/crate-crypto/go-kzg-4844" @@ -270,7 +269,7 @@ func TestKZGWithPoint(t *testing.T) { } // Compute all the prover fields - snarkHash, err := blob.MiMCChecksumPackedData(blobBytes[:], fr381.Bits-1, blob.NoTerminalSymbol()) + snarkHash, err := encode.MiMCChecksumPackedData(blobBytes[:], fr381.Bits-1, encode.NoTerminalSymbol()) assert.NoError(t, err) xUnreduced := evaluationChallenge(snarkHash, blobHash[:]) diff --git a/prover/backend/blobsubmission/craft.go b/prover/backend/blobsubmission/craft.go index 0cdc84215..913a3cb86 100644 --- a/prover/backend/blobsubmission/craft.go +++ b/prover/backend/blobsubmission/craft.go @@ -4,12 +4,12 @@ import ( "encoding/base64" "errors" "fmt" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" "hash" "github.com/consensys/zkevm-monorepo/prover/crypto/mimc" fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - blob "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1" "github.com/consensys/zkevm-monorepo/prover/utils" "golang.org/x/crypto/sha3" ) @@ -72,7 +72,7 @@ func CraftResponseCalldata(req *Request) (*Response, error) { } // Compute all the prover fields - snarkHash, err := blob.MiMCChecksumPackedData(compressedStream, fr381.Bits-1, blob.NoTerminalSymbol()) + snarkHash, err := encode.MiMCChecksumPackedData(compressedStream, fr381.Bits-1, encode.NoTerminalSymbol()) if err != nil { return nil, fmt.Errorf("crafting response: could not compute snark hash: %w", err) } diff --git a/prover/backend/blobsubmission/craft_eip4844.go b/prover/backend/blobsubmission/craft_eip4844.go index 77e896a2a..e7279b435 100644 --- a/prover/backend/blobsubmission/craft_eip4844.go +++ b/prover/backend/blobsubmission/craft_eip4844.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "errors" "fmt" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" blob "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1" @@ -91,7 +92,7 @@ func CraftResponse(req *Request) (*Response, error) { } // Compute all the prover fields - snarkHash, err := blob.MiMCChecksumPackedData(append(compressedStream, make([]byte, blob.MaxUsableBytes-len(compressedStream))...), fr381.Bits-1, blob.NoTerminalSymbol()) + snarkHash, err := encode.MiMCChecksumPackedData(append(compressedStream, make([]byte, blob.MaxUsableBytes-len(compressedStream))...), fr381.Bits-1, encode.NoTerminalSymbol()) if err != nil { return nil, fmt.Errorf("crafting response: could not compute snark hash: %w", err) } diff --git a/prover/backend/ethereum/tx_encoding.go b/prover/backend/ethereum/tx_encoding.go index 9527e37f3..ef8b6a553 100644 --- a/prover/backend/ethereum/tx_encoding.go +++ b/prover/backend/ethereum/tx_encoding.go @@ -146,15 +146,15 @@ func decodeDynamicFeeTx(b *bytes.Reader, tx *types.Transaction) (err error) { parsedTx := types.DynamicFeeTx{} err = errors.Join( - tryCast(&parsedTx.ChainID, decTx[0], "chainID"), - tryCast(&parsedTx.Nonce, decTx[1], "nonce"), - tryCast(&parsedTx.GasTipCap, decTx[2], "gas-tip-cap"), - tryCast(&parsedTx.GasFeeCap, decTx[3], "gas-fee-cap"), - tryCast(&parsedTx.Gas, decTx[4], "gas"), - tryCast(&parsedTx.To, decTx[5], "to"), - tryCast(&parsedTx.Value, decTx[6], "value"), - tryCast(&parsedTx.Data, decTx[7], "data"), - tryCast(&parsedTx.AccessList, decTx[8], "access-list"), + TryCast(&parsedTx.ChainID, decTx[0], "chainID"), + TryCast(&parsedTx.Nonce, decTx[1], "nonce"), + TryCast(&parsedTx.GasTipCap, decTx[2], "gas-tip-cap"), + TryCast(&parsedTx.GasFeeCap, decTx[3], "gas-fee-cap"), + TryCast(&parsedTx.Gas, decTx[4], "gas"), + TryCast(&parsedTx.To, decTx[5], "to"), + TryCast(&parsedTx.Value, decTx[6], "value"), + TryCast(&parsedTx.Data, decTx[7], "data"), + TryCast(&parsedTx.AccessList, decTx[8], "access-list"), ) *tx = *types.NewTx(&parsedTx) return err @@ -176,14 +176,14 @@ func decodeAccessListTx(b *bytes.Reader, tx *types.Transaction) (err error) { parsedTx := types.AccessListTx{} err = errors.Join( - tryCast(&parsedTx.ChainID, decTx[0], "chainID"), - tryCast(&parsedTx.Nonce, decTx[1], "nonce"), - tryCast(&parsedTx.GasPrice, decTx[2], "gas-price"), - tryCast(&parsedTx.Gas, decTx[3], "gas"), - tryCast(&parsedTx.To, decTx[4], "to"), - tryCast(&parsedTx.Value, decTx[5], "value"), - tryCast(&parsedTx.Data, decTx[6], "data"), - tryCast(&parsedTx.AccessList, decTx[7], "access-list"), + TryCast(&parsedTx.ChainID, decTx[0], "chainID"), + TryCast(&parsedTx.Nonce, decTx[1], "nonce"), + TryCast(&parsedTx.GasPrice, decTx[2], "gas-price"), + TryCast(&parsedTx.Gas, decTx[3], "gas"), + TryCast(&parsedTx.To, decTx[4], "to"), + TryCast(&parsedTx.Value, decTx[5], "value"), + TryCast(&parsedTx.Data, decTx[6], "data"), + TryCast(&parsedTx.AccessList, decTx[7], "access-list"), ) *tx = *types.NewTx(&parsedTx) @@ -211,22 +211,22 @@ func decodeLegacyTx(b *bytes.Reader, tx *types.Transaction) (err error) { parsedTx := types.LegacyTx{} err = errors.Join( - tryCast(&parsedTx.Nonce, decTx[0], "nonce"), - tryCast(&parsedTx.GasPrice, decTx[1], "gas-price"), - tryCast(&parsedTx.Gas, decTx[2], "gas"), - tryCast(&parsedTx.To, decTx[3], "to"), - tryCast(&parsedTx.Value, decTx[4], "value"), - tryCast(&parsedTx.Data, decTx[5], "data"), + TryCast(&parsedTx.Nonce, decTx[0], "nonce"), + TryCast(&parsedTx.GasPrice, decTx[1], "gas-price"), + TryCast(&parsedTx.Gas, decTx[2], "gas"), + TryCast(&parsedTx.To, decTx[3], "to"), + TryCast(&parsedTx.Value, decTx[4], "value"), + TryCast(&parsedTx.Data, decTx[5], "data"), ) *tx = *types.NewTx(&parsedTx) return err } -// tryCast will attempt to set t with the underlying value of `from` will return +// TryCast will attempt to set t with the underlying value of `from` will return // an error if the type does not match. The explainer string is used to generate // the error if any. -func tryCast[T any](into *T, from any, explainer string) error { +func TryCast[T any](into *T, from any, explainer string) error { if into == nil || from == nil { return fmt.Errorf("from or into is/are nil") @@ -234,7 +234,7 @@ func tryCast[T any](into *T, from any, explainer string) error { // The rlp encoding is not "type-aware", if the underlying field is an // access-list, it will decode into []interface{} (and we recursively parse - // it) otherwise, it always decode to `[]byte` + // it) otherwise, it always decodes to `[]byte` if list, ok := (from).([]interface{}); ok { var ( @@ -249,7 +249,7 @@ func tryCast[T any](into *T, from any, explainer string) error { for i := range accessList { err = errors.Join( err, - tryCast(&accessList[i], list[i], fmt.Sprintf("%v[%v]", explainer, i)), + TryCast(&accessList[i], list[i], fmt.Sprintf("%v[%v]", explainer, i)), ) } *into = (any(accessList)).(T) @@ -258,8 +258,8 @@ func tryCast[T any](into *T, from any, explainer string) error { case types.AccessTuple: tuple := types.AccessTuple{} err = errors.Join( - tryCast(&tuple.Address, list[0], fmt.Sprintf("%v.%v", explainer, "address")), - tryCast(&tuple.StorageKeys, list[1], fmt.Sprintf("%v.%v", explainer, "storage-key")), + TryCast(&tuple.Address, list[0], fmt.Sprintf("%v.%v", explainer, "address")), + TryCast(&tuple.StorageKeys, list[1], fmt.Sprintf("%v.%v", explainer, "storage-key")), ) *into = (any(tuple)).(T) return err @@ -267,7 +267,7 @@ func tryCast[T any](into *T, from any, explainer string) error { case []common.Hash: hashes := make([]common.Hash, length) for i := range hashes { - tryCast(&hashes[i], list[i], fmt.Sprintf("%v[%v]", explainer, i)) + TryCast(&hashes[i], list[i], fmt.Sprintf("%v[%v]", explainer, i)) } *into = (any(hashes)).(T) return err @@ -295,7 +295,7 @@ func tryCast[T any](into *T, from any, explainer string) error { *into = any(address).(T) case common.Hash: // Parse the bytes as an UTF8 string (= direct casting in go). - // Then, the string as an hexstring encoded address. + // Then, the string as a hexstring encoded address. hash := common.BytesToHash(fromBytes) *into = any(hash).(T) case *big.Int: diff --git a/prover/circuits/blobdecompression/v0/assign_test.go b/prover/circuits/blobdecompression/v0/assign_test.go index b375dfc8c..8398ee3c4 100644 --- a/prover/circuits/blobdecompression/v0/assign_test.go +++ b/prover/circuits/blobdecompression/v0/assign_test.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" v0 "github.com/consensys/zkevm-monorepo/prover/circuits/blobdecompression/v0" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/dictionary" "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1/test_utils" "os" "testing" @@ -73,7 +74,9 @@ func mustGetTestCompressedData(t *testing.T) (resp blobsubmission.Response, blob blobBytes, err = base64.StdEncoding.DecodeString(resp.CompressedData) assert.NoError(t, err) - _, _, _, err = blob.DecompressBlob(blobBytes, dict) + dictStore, err := dictionary.SingletonStore(dict, 0) + assert.NoError(t, err) + _, _, _, err = blob.DecompressBlob(blobBytes, dictStore) assert.NoError(t, err) return diff --git a/prover/circuits/blobdecompression/v0/prelude.go b/prover/circuits/blobdecompression/v0/prelude.go index b8916e721..dbf58aaa5 100644 --- a/prover/circuits/blobdecompression/v0/prelude.go +++ b/prover/circuits/blobdecompression/v0/prelude.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/consensys/zkevm-monorepo/prover/circuits/internal" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/dictionary" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -54,7 +55,12 @@ func Assign(blobData, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.Elem return } - header, uncompressedData, _, err := blob.DecompressBlob(blobData, dict) + dictStore, err := dictionary.SingletonStore(dict, 0) + if err != nil { + err = fmt.Errorf("failed to create dictionary store %w", err) + return + } + header, uncompressedData, _, err := blob.DecompressBlob(blobData, dictStore) if err != nil { err = fmt.Errorf("decompression circuit assignment : could not decompress the data : %w", err) return diff --git a/prover/circuits/blobdecompression/v1/assign_test.go b/prover/circuits/blobdecompression/v1/assign_test.go index 9496770cf..abf6e8b61 100644 --- a/prover/circuits/blobdecompression/v1/assign_test.go +++ b/prover/circuits/blobdecompression/v1/assign_test.go @@ -5,6 +5,7 @@ package v1_test import ( "encoding/base64" "encoding/hex" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/dictionary" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -27,7 +28,9 @@ func prepareTestBlob(t require.TestingT) (c, a frontend.Circuit) { func prepare(t require.TestingT, blobBytes []byte) (c *v1.Circuit, a frontend.Circuit) { - _, payload, _, err := blobcompressorv1.DecompressBlob(blobBytes, blobtestutils.GetDict(t)) + dictStore, err := dictionary.SingletonStore(blobtestutils.GetDict(t), 1) + assert.NoError(t, err) + _, payload, _, err := blobcompressorv1.DecompressBlob(blobBytes, dictStore) assert.NoError(t, err) resp, err := blobsubmission.CraftResponse(&blobsubmission.Request{ diff --git a/prover/circuits/blobdecompression/v1/circuit.go b/prover/circuits/blobdecompression/v1/circuit.go index 194007dc1..590aead53 100644 --- a/prover/circuits/blobdecompression/v1/circuit.go +++ b/prover/circuits/blobdecompression/v1/circuit.go @@ -4,6 +4,8 @@ import ( "bytes" "errors" "fmt" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/dictionary" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" "math/big" "github.com/consensys/gnark-crypto/ecc" @@ -235,7 +237,12 @@ func AssignFPI(blobBytes, dict []byte, eip4844Enabled bool, x [32]byte, y fr381. return } - header, payload, _, err := blob.DecompressBlob(blobBytes, dict) + dictStore, err := dictionary.SingletonStore(dict, 1) + if err != nil { + err = fmt.Errorf("failed to create dictionary store %w", err) + return + } + header, payload, _, err := blob.DecompressBlob(blobBytes, dictStore) if err != nil { return } @@ -266,7 +273,7 @@ func AssignFPI(blobBytes, dict []byte, eip4844Enabled bool, x [32]byte, y fr381. if len(blobBytes) != 128*1024 { panic("blobBytes length is not 128*1024") } - fpi.SnarkHash, err = blob.MiMCChecksumPackedData(blobBytes, fr381.Bits-1, blob.NoTerminalSymbol()) // TODO if forced to remove the above check, pad with zeros + fpi.SnarkHash, err = encode.MiMCChecksumPackedData(blobBytes, fr381.Bits-1, encode.NoTerminalSymbol()) // TODO if forced to remove the above check, pad with zeros return } diff --git a/prover/circuits/blobdecompression/v1/snark_test.go b/prover/circuits/blobdecompression/v1/snark_test.go index fb1635b93..8f37102e9 100644 --- a/prover/circuits/blobdecompression/v1/snark_test.go +++ b/prover/circuits/blobdecompression/v1/snark_test.go @@ -4,6 +4,8 @@ import ( "bytes" "crypto/rand" "errors" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/dictionary" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" "github.com/consensys/zkevm-monorepo/prover/utils" "testing" @@ -19,7 +21,7 @@ import ( "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" blob "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1" - blobtesting "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1/test_utils" + blobtestutils "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1/test_utils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -29,7 +31,7 @@ func TestParseHeader(t *testing.T) { maxBlobSize := 1024 blobs := [][]byte{ - blobtesting.GenTestBlob(t, 100000), + blobtestutils.GenTestBlob(t, 100000), } for _, blobData := range blobs { @@ -47,14 +49,17 @@ func TestParseHeader(t *testing.T) { test.NoTestEngine(), } + dictStore, err := dictionary.SingletonStore(blobtestutils.GetDict(t), 1) + assert.NoError(t, err) + for _, blobData := range blobs { - header, _, blocks, err := blob.DecompressBlob(blobData, blobtesting.GetDict(t)) + header, _, blocks, err := blob.DecompressBlob(blobData, dictStore) assert.NoError(t, err) assert.LessOrEqual(t, len(blocks), MaxNbBatches, "too many batches") - unpacked, err := blob.UnpackAlign(blobData, fr381.Bits-1, false) + unpacked, err := encode.UnpackAlign(blobData, fr381.Bits-1, false) require.NoError(t, err) assignment := &testParseHeaderCircuit{ @@ -87,9 +92,9 @@ func TestChecksumBatches(t *testing.T) { var batchEndss [nbAssignments][]int for i := range batchEndss { - batchEndss[i] = make([]int, blobtesting.RandIntn(MaxNbBatches)+1) + batchEndss[i] = make([]int, blobtestutils.RandIntn(MaxNbBatches)+1) for j := range batchEndss[i] { - batchEndss[i][j] = 31 + blobtesting.RandIntn(62) + batchEndss[i][j] = 31 + blobtestutils.RandIntn(62) if j > 0 { batchEndss[i][j] += batchEndss[i][j-1] } @@ -160,7 +165,7 @@ func testChecksumBatches(t *testing.T, blob []byte, batchEndss ...[]int) { Sums: sums, NbBatches: len(batchEnds), } - assignment.Sums[blobtesting.RandIntn(len(batchEnds))] = 3 + assignment.Sums[blobtestutils.RandIntn(len(batchEnds))] = 3 assert.Error(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_377.ScalarField())) @@ -223,7 +228,7 @@ func TestUnpackCircuit(t *testing.T) { runTest := func(b []byte) { var packedBuf bytes.Buffer - _, err := blob.PackAlign(&packedBuf, b, fr381.Bits-1) // todo use two different slices + _, err := encode.PackAlign(&packedBuf, b, fr381.Bits-1) // todo use two different slices assert.NoError(t, err) circuit := unpackCircuit{ @@ -307,7 +312,7 @@ func TestBlobChecksum(t *testing.T) { // aka "snark hash" assignment := testDataChecksumCircuit{ DataBytes: dataVarsPadded[:nPadded], } - assignment.Checksum, err = blob.MiMCChecksumPackedData(dataPadded[:nPadded], fr381.Bits-1, blob.NoTerminalSymbol()) + assignment.Checksum, err = encode.MiMCChecksumPackedData(dataPadded[:nPadded], fr381.Bits-1, encode.NoTerminalSymbol()) assert.NoError(t, err) assert.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_377.ScalarField())) @@ -337,9 +342,11 @@ func (c *testDataChecksumCircuit) Define(api frontend.API) error { } func TestDictHash(t *testing.T) { - blobBytes := blobtesting.GenTestBlob(t, 1) - dict := blobtesting.GetDict(t) - header, _, _, err := blob.DecompressBlob(blobBytes, dict) // a bit roundabout, but the header field is not public + blobBytes := blobtestutils.GenTestBlob(t, 1) + dict := blobtestutils.GetDict(t) + dictStore, err := dictionary.SingletonStore(blobtestutils.GetDict(t), 1) + assert.NoError(t, err) + header, _, _, err := blob.DecompressBlob(blobBytes, dictStore) // a bit roundabout, but the header field is not public assert.NoError(t, err) circuit := testDataDictHashCircuit{ diff --git a/prover/lib/compressor/blob/blob.go b/prover/lib/compressor/blob/blob.go index 2ede531e7..dcf5a513a 100644 --- a/prover/lib/compressor/blob/blob.go +++ b/prover/lib/compressor/blob/blob.go @@ -1,15 +1,16 @@ package blob import ( + "bytes" "errors" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/dictionary" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" + v0 "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v0" + v1 "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1" + "github.com/ethereum/go-ethereum/rlp" "os" "path/filepath" "strings" - - fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/zkevm-monorepo/prover/circuits/blobdecompression/v0/compress" - v1 "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1" ) func GetVersion(blob []byte) uint16 { @@ -23,17 +24,6 @@ func GetVersion(blob []byte) uint16 { return 0 } -// DictionaryChecksum according to the given spec version -func DictionaryChecksum(dict []byte, version uint16) ([]byte, error) { - switch version { - case 1: - return v1.MiMCChecksumPackedData(dict, 8) - case 0: - return compress.ChecksumPaddedBytes(dict, len(dict), hash.MIMC_BLS12_377.New(), fr381.Bits), nil - } - return nil, errors.New("unsupported version") -} - // GetRepoRootPath assumes that current working directory is within the repo func GetRepoRootPath() (string, error) { wd, err := os.Getwd() @@ -57,3 +47,35 @@ func GetDict() ([]byte, error) { dictPath := filepath.Join(repoRoot, "prover/lib/compressor/compressor_dict.bin") return os.ReadFile(dictPath) } + +func DecompressBlob(blob []byte, dictStore dictionary.Store) ([]byte, error) { + vsn := GetVersion(blob) + var ( + blockDecoder func(*bytes.Reader) (encode.DecodedBlockData, error) + blocks [][]byte + err error + ) + switch vsn { + case 0: + _, _, blocks, err = v0.DecompressBlob(blob, dictStore) + blockDecoder = v0.DecodeBlockFromUncompressed + case 1: + _, _, blocks, err = v1.DecompressBlob(blob, dictStore) + blockDecoder = v1.DecodeBlockFromUncompressed + default: + return nil, errors.New("unrecognized blob version") + } + + if err != nil { + return nil, err + } + blockObjs := make([]encode.DecodedBlockData, len(blocks)) + for i, block := range blocks { + + blockObjs[i], err = blockDecoder(bytes.NewReader(block)) + if err != nil { + return nil, err + } + } + return rlp.EncodeToBytes(blockObjs) +} diff --git a/prover/lib/compressor/blob/dictionary/dictionary.go b/prover/lib/compressor/blob/dictionary/dictionary.go new file mode 100644 index 000000000..af7c46cd1 --- /dev/null +++ b/prover/lib/compressor/blob/dictionary/dictionary.go @@ -0,0 +1,76 @@ +package dictionary + +import ( + "bytes" + "errors" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/hash" + "github.com/consensys/zkevm-monorepo/prover/circuits/blobdecompression/v0/compress" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" + "os" +) + +// Checksum according to the given spec version +func Checksum(dict []byte, version uint16) ([]byte, error) { + switch version { + case 1: + return encode.MiMCChecksumPackedData(dict, 8) + case 0: + return compress.ChecksumPaddedBytes(dict, len(dict), hash.MIMC_BLS12_377.New(), fr.Bits), nil + } + return nil, errors.New("unsupported version") +} + +type Store []map[string][]byte + +func NewStore() Store { + res := make(Store, 2) + for i := range res { + res[i] = make(map[string][]byte) + } + return res +} + +func SingletonStore(dict []byte, version uint16) (Store, error) { + s := make(Store, version+1) + key, err := Checksum(dict, version) + s[version] = make(map[string][]byte, 1) + s[version][string(key)] = dict + return s, err +} + +func (s Store) Load(paths ...string) error { + loadVsn := func(vsn uint16) error { + for _, path := range paths { + dict, err := os.ReadFile(path) + if err != nil { + return err + } + + checksum, err := Checksum(dict, vsn) + if err != nil { + return err + } + key := string(checksum) + existing, exists := s[vsn][key] + if exists && !bytes.Equal(dict, existing) { // should be incredibly unlikely + return errors.New("unmatching dictionary found") + } + s[vsn][key] = dict + } + return nil + } + + return errors.Join(loadVsn(0), loadVsn(1)) +} + +func (s Store) Get(checksum []byte, version uint16) ([]byte, error) { + if int(version) > len(s) { + return nil, errors.New("unrecognized blob version") + } + res, ok := s[version][string(checksum)] + if !ok { + return nil, errors.New("dictionary not found") + } + return res, nil +} diff --git a/prover/lib/compressor/blob/encode/encode.go b/prover/lib/compressor/blob/encode/encode.go new file mode 100644 index 000000000..53f7dbbdb --- /dev/null +++ b/prover/lib/compressor/blob/encode/encode.go @@ -0,0 +1,231 @@ +package encode + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/hash" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/icza/bitio" + "io" +) + +// UnpackAlign unpacks r (packed with PackAlign) and returns the unpacked data. +func UnpackAlign(r []byte, packingSize int, noTerminalSymbol bool) ([]byte, error) { + bytesPerElem := (packingSize + 7) / 8 + packingSizeLastU64 := uint8(packingSize % 64) + if packingSizeLastU64 == 0 { + packingSizeLastU64 = 64 + } + + n := len(r) / bytesPerElem + if n*bytesPerElem != len(r) { + return nil, fmt.Errorf("invalid data length; expected multiple of %d", bytesPerElem) + } + + var out bytes.Buffer + w := bitio.NewWriter(&out) + for i := 0; i < n; i++ { + // read bytes + element := r[bytesPerElem*i : bytesPerElem*(i+1)] + // write bits + w.TryWriteBits(binary.BigEndian.Uint64(element[0:8]), packingSizeLastU64) + for j := 8; j < bytesPerElem; j += 8 { + w.TryWriteBits(binary.BigEndian.Uint64(element[j:j+8]), 64) + } + } + if w.TryError != nil { + return nil, fmt.Errorf("when writing to bitio.Writer: %w", w.TryError) + } + if err := w.Close(); err != nil { + return nil, fmt.Errorf("when closing bitio.Writer: %w", err) + } + + if !noTerminalSymbol { + // the last nonzero byte should be 0xff + outLen := out.Len() - 1 + for out.Bytes()[outLen] == 0 { + outLen-- + } + if out.Bytes()[outLen] != 0xff { + return nil, errors.New("invalid terminal symbol") + } + out.Truncate(outLen) + } + + return out.Bytes(), nil +} + +type packAlignSettings struct { + dataNbBits int + lastByteNbUnusedBits uint8 + noTerminalSymbol bool + additionalInput [][]byte +} + +func (s *packAlignSettings) initialize(length int, options ...packAlignOption) { + + for _, opt := range options { + opt(s) + } + + nbBytes := length + for _, data := range s.additionalInput { + nbBytes += len(data) + } + + if !s.noTerminalSymbol { + nbBytes++ + } + + s.dataNbBits = nbBytes*8 - int(s.lastByteNbUnusedBits) +} + +type packAlignOption func(*packAlignSettings) + +func NoTerminalSymbol() packAlignOption { + return func(o *packAlignSettings) { + o.noTerminalSymbol = true + } +} + +// PackAlignSize returns the size of the data when packed with PackAlign. +func PackAlignSize(length0, packingSize int, options ...packAlignOption) (n int) { + var s packAlignSettings + s.initialize(length0, options...) + + // we may need to add some bits to a and b to ensure we can process some blocks of 248 bits + extraBits := (packingSize - s.dataNbBits%packingSize) % packingSize + nbBits := s.dataNbBits + extraBits + + return (nbBits / packingSize) * ((packingSize + 7) / 8) +} + +func WithAdditionalInput(data ...[]byte) packAlignOption { + return func(o *packAlignSettings) { + o.additionalInput = append(o.additionalInput, data...) + } +} + +func WithLastByteNbUnusedBits(n uint8) packAlignOption { + if n > 7 { + panic("only 8 bits to a byte") + } + return func(o *packAlignSettings) { + o.lastByteNbUnusedBits = n + } +} + +// PackAlign writes a and b to w, aligned to fr.Element (bls12-377) boundary. +// It returns the length of the data written to w. +func PackAlign(w io.Writer, a []byte, packingSize int, options ...packAlignOption) (n int64, err error) { + + var s packAlignSettings + s.initialize(len(a), options...) + if !s.noTerminalSymbol && s.lastByteNbUnusedBits != 0 { + return 0, errors.New("terminal symbols with byte aligned input not yet supported") + } + + // we may need to add some bits to a and b to ensure we can process some blocks of packingSize bits + nbBits := (s.dataNbBits + (packingSize - 1)) / packingSize * packingSize + extraBits := nbBits - s.dataNbBits + + // padding will always be less than bytesPerElem bytes + bytesPerElem := (packingSize + 7) / 8 + packingSizeLastU64 := uint8(packingSize % 64) + if packingSizeLastU64 == 0 { + packingSizeLastU64 = 64 + } + bytePadding := (extraBits + 7) / 8 + buf := make([]byte, bytesPerElem, bytesPerElem+1) + + // the last nonzero byte is 0xff + if !s.noTerminalSymbol { + buf = append(buf, 0) + buf[0] = 0xff + } + + inReaders := make([]io.Reader, 2+len(s.additionalInput)) + inReaders[0] = bytes.NewReader(a) + for i, data := range s.additionalInput { + inReaders[i+1] = bytes.NewReader(data) + } + inReaders[len(inReaders)-1] = bytes.NewReader(buf[:bytePadding+1]) + + r := bitio.NewReader(io.MultiReader(inReaders...)) + + var tryWriteErr error + tryWrite := func(v uint64) { + if tryWriteErr == nil { + tryWriteErr = binary.Write(w, binary.BigEndian, v) + } + } + + for i := 0; i < nbBits/packingSize; i++ { + tryWrite(r.TryReadBits(packingSizeLastU64)) + for j := int(packingSizeLastU64); j < packingSize; j += 64 { + tryWrite(r.TryReadBits(64)) + } + } + + if tryWriteErr != nil { + return 0, fmt.Errorf("when writing to w: %w", tryWriteErr) + } + + if r.TryError != nil { + return 0, fmt.Errorf("when reading from multi-reader: %w", r.TryError) + } + + n1 := (nbBits / packingSize) * bytesPerElem + if n1 != PackAlignSize(len(a), packingSize, options...) { + return 0, errors.New("inconsistent PackAlignSize") + } + return int64(n1), nil +} + +// MiMCChecksumPackedData re-packs the data tightly into bls12-377 elements and computes the MiMC checksum. +// only supporting packing without a terminal symbol. Input with a terminal symbol will be interpreted in full padded length. +func MiMCChecksumPackedData(data []byte, inputPackingSize int, hashPackingOptions ...packAlignOption) ([]byte, error) { + dataNbBits := len(data) * 8 + if inputPackingSize%8 != 0 { + inputBytesPerElem := (inputPackingSize + 7) / 8 + dataNbBits = dataNbBits / inputBytesPerElem * inputPackingSize + var err error + if data, err = UnpackAlign(data, inputPackingSize, true); err != nil { + return nil, err + } + } + + lastByteNbUnusedBits := 8 - dataNbBits%8 + if lastByteNbUnusedBits == 8 { + lastByteNbUnusedBits = 0 + } + + var bb bytes.Buffer + packingOptions := make([]packAlignOption, len(hashPackingOptions)+1) + copy(packingOptions, hashPackingOptions) + packingOptions[len(packingOptions)-1] = WithLastByteNbUnusedBits(uint8(lastByteNbUnusedBits)) + if _, err := PackAlign(&bb, data, fr.Bits-1, packingOptions...); err != nil { + return nil, err + } + + hsh := hash.MIMC_BLS12_377.New() + hsh.Write(bb.Bytes()) + return hsh.Sum(nil), nil +} + +// DecodedBlockData is a wrapper struct storing the different fields of a block +// that we deserialize when decoding an ethereum block. +type DecodedBlockData struct { + // BlockHash stores the decoded block hash + BlockHash common.Hash + // Timestamp holds the Unix timestamp of the block in + Timestamp uint64 + // Froms stores the list of the sender address of every transaction + Froms []common.Address + // Txs stores the list of the decoded transactions. + Txs []types.Transaction +} diff --git a/prover/lib/compressor/libcompressor.go b/prover/lib/compressor/blob/libcompressor/libcompressor.go similarity index 100% rename from prover/lib/compressor/libcompressor.go rename to prover/lib/compressor/blob/libcompressor/libcompressor.go diff --git a/prover/lib/compressor/libcompressor.h b/prover/lib/compressor/blob/libcompressor/libcompressor.h similarity index 100% rename from prover/lib/compressor/libcompressor.h rename to prover/lib/compressor/blob/libcompressor/libcompressor.h diff --git a/prover/lib/compressor/blob/libdecompressor/libdecompressor.go b/prover/lib/compressor/blob/libdecompressor/libdecompressor.go new file mode 100644 index 000000000..720ecfa60 --- /dev/null +++ b/prover/lib/compressor/blob/libdecompressor/libdecompressor.go @@ -0,0 +1,74 @@ +package main + +import "C" + +import ( + "errors" + "sync" + "unsafe" + + decompressor "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/dictionary" +) + +//go:generate go build -tags nocorset -ldflags "-s -w" -buildmode=c-shared -o libdecompressor.so libdecompressor.go +func main() {} + +var ( + dictStore dictionary.Store + lastError error + lock sync.Mutex // probably unnecessary if coordinator guarantees single-threaded access +) + +// Init initializes the decompressor. +// +//export Init +func Init() { + dictStore = dictionary.NewStore() +} + +// LoadDictionary loads a particular dictionary into the decompressor +// Returns true if the operation is successful, false otherwise. +// If false is returned, the Error() method will return a string describing the error. +// +//export LoadDictionary +func LoadDictionary(dictPath *C.char) bool { + lock.Lock() + defer lock.Unlock() + fPath := C.GoString(dictPath) + if err := dictStore.Load(fPath); err != nil { + lastError = err + return false + } + return true +} + +// Decompress processes a blob b and writes the resulting blocks in out, serialized in the format of +// prover/backend/ethereum. +// Returns the number of bytes in out, or -1 in case of failure +// If -1 is returned, the Error() method will return a string describing the error. +// +//export Decompress +func Decompress(blob *C.char, blobLength C.int, out *C.char, outMaxLength C.int) C.int { + + lock.Lock() + defer lock.Unlock() + + bGo := C.GoBytes(unsafe.Pointer(blob), blobLength) + + blocks, err := decompressor.DecompressBlob(bGo, dictStore) + if err != nil { + lastError = err + return -1 + } + + if len(blocks) > int(outMaxLength) { + lastError = errors.New("decoded blob does not fit in output buffer") + return -1 + } + + outSlice := unsafe.Slice((*byte)(unsafe.Pointer(out)), len(blocks)) + copy(outSlice, blocks) + + return C.int(len(blocks)) +} diff --git a/prover/lib/compressor/blob/v0/blob_maker.go b/prover/lib/compressor/blob/v0/blob_maker.go index 3cab0b350..54208bfa1 100644 --- a/prover/lib/compressor/blob/v0/blob_maker.go +++ b/prover/lib/compressor/blob/v0/blob_maker.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/dictionary" "io" "os" "strings" @@ -39,6 +40,7 @@ type BlobMaker struct { limit int // maximum size of the compressed data compressor *lzss.Compressor // compressor used to compress the blob body dict []byte // dictionary used for compression + dictStore dictionary.Store header Header @@ -67,6 +69,10 @@ func NewBlobMaker(dataLimit int, dictPath string) (*BlobMaker, error) { } dict = lzss.AugmentDict(dict) blobMaker.dict = dict + blobMaker.dictStore, err = dictionary.SingletonStore(dict, 0) + if err != nil { + return nil, err + } dictChecksum := compress.ChecksumPaddedBytes(dict, len(dict), hash.MIMC_BLS12_377.New(), fr.Bits) copy(blobMaker.header.DictChecksum[:], dictChecksum) @@ -119,7 +125,7 @@ func (bm *BlobMaker) Written() int { func (bm *BlobMaker) Bytes() []byte { if bm.currentBlobLength > 0 { // sanity check that we can always decompress. - header, rawBlocks, _, err := DecompressBlob(bm.currentBlob[:bm.currentBlobLength], bm.dict) + header, rawBlocks, _, err := DecompressBlob(bm.currentBlob[:bm.currentBlobLength], bm.dictStore) if err != nil { var sbb strings.Builder fmt.Fprintf(&sbb, "invalid blob: %v\n", err) @@ -282,7 +288,7 @@ func (bm *BlobMaker) Equals(other *BlobMaker) bool { } // DecompressBlob decompresses a blob and returns the header and the blocks as they were compressed. -func DecompressBlob(b, dict []byte) (blobHeader *Header, rawBlocks []byte, blocks [][]byte, err error) { +func DecompressBlob(b []byte, dictStore dictionary.Store) (blobHeader *Header, rawBlocks []byte, blocks [][]byte, err error) { // UnpackAlign the blob b, err = UnpackAlign(b) if err != nil { @@ -296,11 +302,10 @@ func DecompressBlob(b, dict []byte) (blobHeader *Header, rawBlocks []byte, block return nil, nil, nil, fmt.Errorf("failed to read blob header: %w", err) } - // ensure the dict hash matches - { - if !bytes.Equal(compress.ChecksumPaddedBytes(dict, len(dict), hash.MIMC_BLS12_377.New(), fr.Bits), blobHeader.DictChecksum[:]) { - return nil, nil, nil, errors.New("invalid dict hash") - } + // retrieve dict + dict, err := dictStore.Get(blobHeader.DictChecksum[:], 0) + if err != nil { + return nil, nil, nil, err } b = b[read:] diff --git a/prover/lib/compressor/blob/v0/blob_maker_test.go b/prover/lib/compressor/blob/v0/blob_maker_test.go index e15696188..61f9b5845 100644 --- a/prover/lib/compressor/blob/v0/blob_maker_test.go +++ b/prover/lib/compressor/blob/v0/blob_maker_test.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/dictionary" "io" "math/big" "math/rand" @@ -603,7 +604,11 @@ func decompressBlob(b []byte) ([][][]byte, error) { if err != nil { return nil, fmt.Errorf("can't read dict: %w", err) } - header, _, blocks, err := DecompressBlob(b, dict) + dictStore, err := dictionary.SingletonStore(dict, 0) + if err != nil { + return nil, err + } + header, _, blocks, err := DecompressBlob(b, dictStore) if err != nil { return nil, fmt.Errorf("can't decompress blob: %w", err) } diff --git a/prover/lib/compressor/blob/v0/encode.go b/prover/lib/compressor/blob/v0/encode.go index 490b401b8..7f4d22505 100644 --- a/prover/lib/compressor/blob/v0/encode.go +++ b/prover/lib/compressor/blob/v0/encode.go @@ -1,9 +1,13 @@ package v0 import ( + "bytes" "encoding/binary" + "errors" "fmt" "github.com/consensys/zkevm-monorepo/prover/backend/ethereum" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" + "github.com/ethereum/go-ethereum/common" "io" "github.com/ethereum/go-ethereum/core/types" @@ -97,3 +101,163 @@ func EncodeTxForCompression(tx *types.Transaction, w io.Writer) error { return nil } + +// DecodeBlockFromUncompressed inverts [EncodeBlockForCompression]. It is primarily meant for +// testing and ensuring the encoding is bijective. +func DecodeBlockFromUncompressed(r *bytes.Reader) (encode.DecodedBlockData, error) { + + /* + if err := binary.Write(w, binary.LittleEndian, block.Time()); err != nil { + return err + } + for _, tx := range block.Transactions() { + if err := EncodeTxForCompression(tx, w); err != nil { + return err + } + } + return nil + */ + + var decTimestamp uint64 + + if err := binary.Read(r, binary.BigEndian, &decTimestamp); err != nil { + return encode.DecodedBlockData{}, fmt.Errorf("could not decode timestamp: %w", err) + } + + decodedBlk := encode.DecodedBlockData{ + Timestamp: decTimestamp, + } + + for r.Len() != 0 { + var ( + tx types.Transaction + from common.Address + ) + if err := DecodeTxFromUncompressed(r, &tx, &from); err != nil { + return encode.DecodedBlockData{}, fmt.Errorf("could not decode transaction #%v: %w", len(decodedBlk.Txs), err) + } + decodedBlk.Froms = append(decodedBlk.Froms, from) + decodedBlk.Txs = append(decodedBlk.Txs, tx) + } + + return decodedBlk, nil +} + +func DecodeTxFromUncompressed(r *bytes.Reader, tx *types.Transaction, from *common.Address) error { + if _, err := r.Read(from[:]); err != nil { + return fmt.Errorf("could not read from address: %w", err) + } + + if err := ethereum.DecodeTxFromBytes(r, tx); err != nil { + return fmt.Errorf("could not deserialize transaction") + } + + firstByte, err := r.ReadByte() + if err != nil { + return fmt.Errorf("could not read the first byte: %w", err) + } + + switch { + case firstByte == types.DynamicFeeTxType: + return decodeDynamicFeeTx(r, tx, from) + case firstByte == types.AccessListTxType: + return decodeAccessListTx(r, tx, from) + // According to the RLP rule, `0xc0 + x` or `0xf7` indicates that the current + // item is a list and this is what's used to identify that the transaction is + // a legacy transaction or a EIP-155 transaction. + // + // Note that 0xc0 would indicate an empty list and thus be an invalid tx. + case firstByte > 0xc0: + // Set the byte-reader backward so that we can apply the rlp-decoder + // over it. + r.UnreadByte() + return decodeLegacyTx(r, tx, from) + } + + return fmt.Errorf("unexpected first byte: %x", firstByte) + + return nil +} + +func decodeLegacyTx(r *bytes.Reader, tx *types.Transaction, from *common.Address) (err error) { + decTx := []any{} + + if err := rlp.Decode(r, &decTx); err != nil { + return fmt.Errorf("could not rlp decode transaction: %w", err) + } + + if len(decTx) != 7 { + return fmt.Errorf("unexpected number of field") + } + + parsedTx := types.LegacyTx{} + err = errors.Join( + ethereum.TryCast(&parsedTx.Nonce, decTx[0], "nonce"), + ethereum.TryCast(&parsedTx.GasPrice, decTx[1], "gas-price"), + ethereum.TryCast(&parsedTx.Gas, decTx[2], "gas"), + ethereum.TryCast(from, decTx[3], "from"), + ethereum.TryCast(&parsedTx.To, decTx[4], "to"), + ethereum.TryCast(&parsedTx.Value, decTx[5], "value"), + ethereum.TryCast(&parsedTx.Data, decTx[6], "data"), + ) + + *tx = *types.NewTx(&parsedTx) + return err +} + +func decodeAccessListTx(r *bytes.Reader, tx *types.Transaction, from *common.Address) (err error) { + + decTx := []any{} + + if err := rlp.Decode(r, &decTx); err != nil { + return fmt.Errorf("could not rlp decode transaction: %w", err) + } + + if len(decTx) != 8 { + return fmt.Errorf("invalid number of field for a dynamic transaction") + } + + parsedTx := types.AccessListTx{} + err = errors.Join( + ethereum.TryCast(&parsedTx.Nonce, decTx[0], "nonce"), + ethereum.TryCast(&parsedTx.GasPrice, decTx[1], "gas-price"), + ethereum.TryCast(&parsedTx.Gas, decTx[2], "gas"), + ethereum.TryCast(from, decTx[3], "from"), + ethereum.TryCast(&parsedTx.To, decTx[4], "to"), + ethereum.TryCast(&parsedTx.Value, decTx[5], "value"), + ethereum.TryCast(&parsedTx.Data, decTx[6], "data"), + ethereum.TryCast(&parsedTx.AccessList, decTx[7], "access-list"), + ) + + *tx = *types.NewTx(&parsedTx) + return err +} + +func decodeDynamicFeeTx(r *bytes.Reader, tx *types.Transaction, from *common.Address) (err error) { + + decTx := []any{} + + if err := rlp.Decode(r, &decTx); err != nil { + return fmt.Errorf("could not rlp decode transaction: %w", err) + } + + if len(decTx) != 9 { + return fmt.Errorf("invalid number of field for a dynamic transaction") + } + + parsedTx := types.DynamicFeeTx{} + err = errors.Join( + ethereum.TryCast(&parsedTx.Nonce, decTx[0], "nonce"), + ethereum.TryCast(&parsedTx.GasTipCap, decTx[1], "gas-tip-cap"), + ethereum.TryCast(&parsedTx.GasFeeCap, decTx[2], "gas-fee-cap"), + ethereum.TryCast(&parsedTx.Gas, decTx[3], "gas"), + ethereum.TryCast(from, decTx[4], "from"), + ethereum.TryCast(&parsedTx.To, decTx[5], "to"), + ethereum.TryCast(&parsedTx.Value, decTx[6], "value"), + ethereum.TryCast(&parsedTx.Data, decTx[7], "data"), + ethereum.TryCast(&parsedTx.AccessList, decTx[8], "access-list"), + ) + *tx = *types.NewTx(&parsedTx) + return err + +} diff --git a/prover/lib/compressor/blob/v1/blob_maker.go b/prover/lib/compressor/blob/v1/blob_maker.go index cd45430e0..1638348d6 100644 --- a/prover/lib/compressor/blob/v1/blob_maker.go +++ b/prover/lib/compressor/blob/v1/blob_maker.go @@ -2,10 +2,10 @@ package v1 import ( "bytes" - "encoding/binary" "errors" "fmt" - "io" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/dictionary" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" "os" "slices" "strings" @@ -13,10 +13,6 @@ import ( fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/sirupsen/logrus" - fr377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/hash" - "github.com/icza/bitio" - "github.com/consensys/compress/lzss" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/rlp" @@ -40,6 +36,7 @@ type BlobMaker struct { Limit int // maximum size of the compressed data compressor *lzss.Compressor // compressor used to compress the blob body dict []byte // dictionary used for compression + dictStore dictionary.Store // dictionary store comprising only dict, used for decompression sanity checks Header Header @@ -68,8 +65,11 @@ func NewBlobMaker(dataLimit int, dictPath string) (*BlobMaker, error) { } dict = lzss.AugmentDict(dict) blobMaker.dict = dict + if blobMaker.dictStore, err = dictionary.SingletonStore(dict, 1); err != nil { + return nil, err + } - dictChecksum, err := MiMCChecksumPackedData(dict, 8) + dictChecksum, err := encode.MiMCChecksumPackedData(dict, 8) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func (bm *BlobMaker) Written() int { func (bm *BlobMaker) Bytes() []byte { if bm.currentBlobLength > 0 { // sanity check that we can always decompress. - header, rawBlocks, _, err := DecompressBlob(bm.currentBlob[:bm.currentBlobLength], bm.dict) + header, rawBlocks, _, err := DecompressBlob(bm.currentBlob[:bm.currentBlobLength], bm.dictStore) if err != nil { var sbb strings.Builder fmt.Fprintf(&sbb, "invalid blob: %v\n", err) @@ -192,13 +192,13 @@ func (bm *BlobMaker) Write(rlpBlock []byte, forceReset bool) (ok bool, err error } // check that the header + the compressed data fits in the blob - fitsInBlob := PackAlignSize(bm.buf.Len()+bm.compressor.Len(), fr381.Bits-1) <= bm.Limit + fitsInBlob := encode.PackAlignSize(bm.buf.Len()+bm.compressor.Len(), fr381.Bits-1) <= bm.Limit if !fitsInBlob { // first thing to check is if we bypass compression, would that fit? if bm.compressor.ConsiderBypassing() { // we can bypass compression and get a better ratio. // let's check if now we fit in the blob. - if PackAlignSize(bm.buf.Len()+bm.compressor.Len(), fr381.Bits-1) <= bm.Limit { + if encode.PackAlignSize(bm.buf.Len()+bm.compressor.Len(), fr381.Bits-1) <= bm.Limit { goto bypass } } @@ -222,7 +222,7 @@ bypass: // copy the compressed data to the blob bm.packBuffer.Reset() - n2, err := PackAlign(&bm.packBuffer, bm.buf.Bytes(), fr381.Bits-1, WithAdditionalInput(bm.compressor.Bytes())) + n2, err := encode.PackAlign(&bm.packBuffer, bm.buf.Bytes(), fr381.Bits-1, encode.WithAdditionalInput(bm.compressor.Bytes())) if err != nil { bm.compressor.Revert() bm.Header.removeLastBlock() @@ -265,9 +265,9 @@ func (bm *BlobMaker) Equals(other *BlobMaker) bool { } // DecompressBlob decompresses a blob and returns the header and the blocks as they were compressed. -func DecompressBlob(b, dict []byte) (blobHeader *Header, rawPayload []byte, blocks [][]byte, err error) { +func DecompressBlob(b []byte, dictStore dictionary.Store) (blobHeader *Header, rawPayload []byte, blocks [][]byte, err error) { // UnpackAlign the blob - b, err = UnpackAlign(b, fr381.Bits-1, false) + b, err = encode.UnpackAlign(b, fr381.Bits-1, false) if err != nil { return nil, nil, nil, err } @@ -278,15 +278,10 @@ func DecompressBlob(b, dict []byte) (blobHeader *Header, rawPayload []byte, bloc if err != nil { return nil, nil, nil, fmt.Errorf("failed to read blob header: %w", err) } - // ensure the dict hash matches - { - expectedDictChecksum, err := MiMCChecksumPackedData(dict, 8) - if err != nil { - return nil, nil, nil, err - } - if !bytes.Equal(expectedDictChecksum, blobHeader.DictChecksum[:]) { - return nil, nil, nil, errors.New("invalid dict hash") - } + // retrieve dict + dict, err := dictStore.Get(blobHeader.DictChecksum[:], 1) + if err != nil { + return nil, nil, nil, err } b = b[read:] @@ -318,210 +313,6 @@ func DecompressBlob(b, dict []byte) (blobHeader *Header, rawPayload []byte, bloc return blobHeader, rawPayload, blocks, nil } -// PackAlignSize returns the size of the data when packed with PackAlign. -func PackAlignSize(length0, packingSize int, options ...packAlignOption) (n int) { - var s packAlignSettings - s.initialize(length0, options...) - - // we may need to add some bits to a and b to ensure we can process some blocks of 248 bits - extraBits := (packingSize - s.dataNbBits%packingSize) % packingSize - nbBits := s.dataNbBits + extraBits - - return (nbBits / packingSize) * ((packingSize + 7) / 8) -} - -type packAlignSettings struct { - dataNbBits int - lastByteNbUnusedBits uint8 - noTerminalSymbol bool - additionalInput [][]byte -} - -type packAlignOption func(*packAlignSettings) - -func NoTerminalSymbol() packAlignOption { - return func(o *packAlignSettings) { - o.noTerminalSymbol = true - } -} - -func WithAdditionalInput(data ...[]byte) packAlignOption { - return func(o *packAlignSettings) { - o.additionalInput = append(o.additionalInput, data...) - } -} - -func WithLastByteNbUnusedBits(n uint8) packAlignOption { - if n > 7 { - panic("only 8 bits to a byte") - } - return func(o *packAlignSettings) { - o.lastByteNbUnusedBits = n - } -} - -func (s *packAlignSettings) initialize(length int, options ...packAlignOption) { - - for _, opt := range options { - opt(s) - } - - nbBytes := length - for _, data := range s.additionalInput { - nbBytes += len(data) - } - - if !s.noTerminalSymbol { - nbBytes++ - } - - s.dataNbBits = nbBytes*8 - int(s.lastByteNbUnusedBits) -} - -// PackAlign writes a and b to w, aligned to fr.Element (bls12-377) boundary. -// It returns the length of the data written to w. -func PackAlign(w io.Writer, a []byte, packingSize int, options ...packAlignOption) (n int64, err error) { - - var s packAlignSettings - s.initialize(len(a), options...) - if !s.noTerminalSymbol && s.lastByteNbUnusedBits != 0 { - return 0, errors.New("terminal symbols with byte aligned input not yet supported") - } - - // we may need to add some bits to a and b to ensure we can process some blocks of packingSize bits - nbBits := (s.dataNbBits + (packingSize - 1)) / packingSize * packingSize - extraBits := nbBits - s.dataNbBits - - // padding will always be less than bytesPerElem bytes - bytesPerElem := (packingSize + 7) / 8 - packingSizeLastU64 := uint8(packingSize % 64) - if packingSizeLastU64 == 0 { - packingSizeLastU64 = 64 - } - bytePadding := (extraBits + 7) / 8 - buf := make([]byte, bytesPerElem, bytesPerElem+1) - - // the last nonzero byte is 0xff - if !s.noTerminalSymbol { - buf = append(buf, 0) - buf[0] = 0xff - } - - inReaders := make([]io.Reader, 2+len(s.additionalInput)) - inReaders[0] = bytes.NewReader(a) - for i, data := range s.additionalInput { - inReaders[i+1] = bytes.NewReader(data) - } - inReaders[len(inReaders)-1] = bytes.NewReader(buf[:bytePadding+1]) - - r := bitio.NewReader(io.MultiReader(inReaders...)) - - var tryWriteErr error - tryWrite := func(v uint64) { - if tryWriteErr == nil { - tryWriteErr = binary.Write(w, binary.BigEndian, v) - } - } - - for i := 0; i < nbBits/packingSize; i++ { - tryWrite(r.TryReadBits(packingSizeLastU64)) - for j := int(packingSizeLastU64); j < packingSize; j += 64 { - tryWrite(r.TryReadBits(64)) - } - } - - if tryWriteErr != nil { - return 0, fmt.Errorf("when writing to w: %w", tryWriteErr) - } - - if r.TryError != nil { - return 0, fmt.Errorf("when reading from multi-reader: %w", r.TryError) - } - - n1 := (nbBits / packingSize) * bytesPerElem - if n1 != PackAlignSize(len(a), packingSize, options...) { - return 0, errors.New("inconsistent PackAlignSize") - } - return int64(n1), nil -} - -// UnpackAlign unpacks r (packed with PackAlign) and returns the unpacked data. -func UnpackAlign(r []byte, packingSize int, noTerminalSymbol bool) ([]byte, error) { - bytesPerElem := (packingSize + 7) / 8 - packingSizeLastU64 := uint8(packingSize % 64) - if packingSizeLastU64 == 0 { - packingSizeLastU64 = 64 - } - - n := len(r) / bytesPerElem - if n*bytesPerElem != len(r) { - return nil, fmt.Errorf("invalid data length; expected multiple of %d", bytesPerElem) - } - - var out bytes.Buffer - w := bitio.NewWriter(&out) - for i := 0; i < n; i++ { - // read bytes - element := r[bytesPerElem*i : bytesPerElem*(i+1)] - // write bits - w.TryWriteBits(binary.BigEndian.Uint64(element[0:8]), packingSizeLastU64) - for j := 8; j < bytesPerElem; j += 8 { - w.TryWriteBits(binary.BigEndian.Uint64(element[j:j+8]), 64) - } - } - if w.TryError != nil { - return nil, fmt.Errorf("when writing to bitio.Writer: %w", w.TryError) - } - if err := w.Close(); err != nil { - return nil, fmt.Errorf("when closing bitio.Writer: %w", err) - } - - if !noTerminalSymbol { - // the last nonzero byte should be 0xff - outLen := out.Len() - 1 - for out.Bytes()[outLen] == 0 { - outLen-- - } - if out.Bytes()[outLen] != 0xff { - return nil, errors.New("invalid terminal symbol") - } - out.Truncate(outLen) - } - - return out.Bytes(), nil -} - -// MiMCChecksumPackedData re-packs the data tightly into bls12-377 elements and computes the MiMC checksum. -// only supporting packing without a terminal symbol. Input with a terminal symbol will be interpreted in full padded length. -func MiMCChecksumPackedData(data []byte, inputPackingSize int, hashPackingOptions ...packAlignOption) ([]byte, error) { - dataNbBits := len(data) * 8 - if inputPackingSize%8 != 0 { - inputBytesPerElem := (inputPackingSize + 7) / 8 - dataNbBits = dataNbBits / inputBytesPerElem * inputPackingSize - var err error - if data, err = UnpackAlign(data, inputPackingSize, true); err != nil { - return nil, err - } - } - - lastByteNbUnusedBits := 8 - dataNbBits%8 - if lastByteNbUnusedBits == 8 { - lastByteNbUnusedBits = 0 - } - - var bb bytes.Buffer - packingOptions := make([]packAlignOption, len(hashPackingOptions)+1) - copy(packingOptions, hashPackingOptions) - packingOptions[len(packingOptions)-1] = WithLastByteNbUnusedBits(uint8(lastByteNbUnusedBits)) - if _, err := PackAlign(&bb, data, fr377.Bits-1, packingOptions...); err != nil { - return nil, err - } - - hsh := hash.MIMC_BLS12_377.New() - hsh.Write(bb.Bytes()) - return hsh.Sum(nil), nil -} - // WorstCompressedBlockSize returns the size of the given block, as compressed by an "empty" blob maker. // That is, with more context, blob maker could compress the block further, but this function // returns the maximum size that can be achieved. @@ -558,7 +349,7 @@ func (bm *BlobMaker) WorstCompressedBlockSize(rlpBlock []byte) (bool, int, error } // account for the padding - n = PackAlignSize(n, fr381.Bits-1, NoTerminalSymbol()) + n = encode.PackAlignSize(n, fr381.Bits-1, encode.NoTerminalSymbol()) return expandingBlock, n, nil } @@ -611,7 +402,7 @@ func (bm *BlobMaker) RawCompressedSize(data []byte) (int, error) { } // account for the padding - n = PackAlignSize(n, fr381.Bits-1, NoTerminalSymbol()) + n = encode.PackAlignSize(n, fr381.Bits-1, encode.NoTerminalSymbol()) return n, nil } diff --git a/prover/lib/compressor/blob/v1/blob_maker_test.go b/prover/lib/compressor/blob/v1/blob_maker_test.go index 7e726f517..0bc0ea3d0 100644 --- a/prover/lib/compressor/blob/v1/blob_maker_test.go +++ b/prover/lib/compressor/blob/v1/blob_maker_test.go @@ -9,6 +9,8 @@ import ( "encoding/hex" "errors" "fmt" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/dictionary" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" "math/big" "math/rand" "os" @@ -57,7 +59,8 @@ func testCompressorSingleSmallBatch(t *testing.T, blocks [][]byte) { dict, err := os.ReadFile(testDictPath) assert.NoError(t, err) - _, _, blocksBack, err := v1.DecompressBlob(bm.Bytes(), dict) + dictStore, err := dictionary.SingletonStore(dict, 1) + _, _, blocksBack, err := v1.DecompressBlob(bm.Bytes(), dictStore) assert.NoError(t, err) assert.Equal(t, len(blocks), len(blocksBack), "number of blocks should match") // TODO compare the blocks @@ -121,7 +124,7 @@ func assertBatchesConsistent(t *testing.T, raw, decoded [][]byte) { var block types.Block assert.NoError(t, rlp.Decode(bytes.NewReader(raw[i]), &block)) - blockBack, err := test_utils.DecodeBlockFromUncompressed(bytes.NewReader(decoded[i])) + blockBack, err := v1.DecodeBlockFromUncompressed(bytes.NewReader(decoded[i])) assert.NoError(t, err) assert.Equal(t, block.Time(), blockBack.Timestamp, "block time should match") } @@ -512,7 +515,11 @@ func decompressBlob(b []byte) ([][][]byte, error) { if err != nil { return nil, fmt.Errorf("can't read dict: %w", err) } - header, _, blocks, err := v1.DecompressBlob(b, dict) + dictStore, err := dictionary.SingletonStore(dict, 1) + if err != nil { + return nil, err + } + header, _, blocks, err := v1.DecompressBlob(b, dictStore) if err != nil { return nil, fmt.Errorf("can't decompress blob: %w", err) } @@ -641,10 +648,10 @@ func TestPack(t *testing.T) { runTest := func(s1, s2 []byte) { // pack them buf.Reset() - written, err := v1.PackAlign(&buf, s1, fr381.Bits-1, v1.WithAdditionalInput(s2)) + written, err := encode.PackAlign(&buf, s1, fr381.Bits-1, encode.WithAdditionalInput(s2)) assert.NoError(err, "pack should not generate an error") - assert.Equal(v1.PackAlignSize(len(s1)+len(s2), fr381.Bits-1), int(written), "written bytes should match expected PackAlignSize") - original, err := v1.UnpackAlign(buf.Bytes(), fr381.Bits-1, false) + assert.Equal(encode.PackAlignSize(len(s1)+len(s2), fr381.Bits-1), int(written), "written bytes should match expected PackAlignSize") + original, err := encode.UnpackAlign(buf.Bytes(), fr381.Bits-1, false) assert.NoError(err, "unpack should not generate an error") assert.Equal(s1, original[:len(s1)], "slices should match") diff --git a/prover/lib/compressor/blob/v1/encode.go b/prover/lib/compressor/blob/v1/encode.go index 1ce7355b9..bbac091f2 100644 --- a/prover/lib/compressor/blob/v1/encode.go +++ b/prover/lib/compressor/blob/v1/encode.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" "io" "github.com/consensys/zkevm-monorepo/prover/backend/ethereum" @@ -159,3 +160,54 @@ func PassRlpList(r *bytes.Reader) error { return nil } + +// DecodeBlockFromUncompressed inverts [EncodeBlockForCompression]. It is primarily meant for +// testing and ensuring the encoding is bijective. +func DecodeBlockFromUncompressed(r *bytes.Reader) (encode.DecodedBlockData, error) { + + var ( + decNumTxs uint16 + decTimestamp uint32 + blockHash common.Hash + ) + + if err := binary.Read(r, binary.BigEndian, &decNumTxs); err != nil { + return encode.DecodedBlockData{}, fmt.Errorf("could not decode nb txs: %w", err) + } + + if err := binary.Read(r, binary.BigEndian, &decTimestamp); err != nil { + return encode.DecodedBlockData{}, fmt.Errorf("could not decode timestamp: %w", err) + } + + if _, err := r.Read(blockHash[:]); err != nil { + return encode.DecodedBlockData{}, fmt.Errorf("could not read the block hash: %w", err) + } + + numTxs := int(decNumTxs) + decodedBlk := encode.DecodedBlockData{ + Froms: make([]common.Address, numTxs), + Txs: make([]types.Transaction, numTxs), + Timestamp: uint64(decTimestamp), + BlockHash: blockHash, + } + + for i := 0; i < int(decNumTxs); i++ { + if err := DecodeTxFromUncompressed(r, &decodedBlk.Txs[i], &decodedBlk.Froms[i]); err != nil { + return encode.DecodedBlockData{}, fmt.Errorf("could not decode transaction #%v: %w", i, err) + } + } + + return decodedBlk, nil +} + +func DecodeTxFromUncompressed(r *bytes.Reader, tx *types.Transaction, from *common.Address) (err error) { + if _, err := r.Read(from[:]); err != nil { + return fmt.Errorf("could not read from address: %w", err) + } + + if err := ethereum.DecodeTxFromBytes(r, tx); err != nil { + return fmt.Errorf("could not deserialize transaction") + } + + return nil +} diff --git a/prover/lib/compressor/blob/v1/encode_test.go b/prover/lib/compressor/blob/v1/encode_test.go index 732ea8b9a..a058284c2 100644 --- a/prover/lib/compressor/blob/v1/encode_test.go +++ b/prover/lib/compressor/blob/v1/encode_test.go @@ -6,6 +6,7 @@ import ( "bytes" "encoding/hex" "fmt" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" "testing" v1 "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1" @@ -33,29 +34,25 @@ func TestEncodeDecode(t *testing.T) { t.Fatalf("could not decode test RLP block: %s", err.Error()) } - var ( - buf = &bytes.Buffer{} - expected = test_utils.DecodedBlockData{ - BlockHash: block.Hash(), - Txs: make([]ethtypes.Transaction, len(block.Transactions())), - Timestamp: block.Time(), - } - ) + var buf bytes.Buffer + expected := encode.DecodedBlockData{ + BlockHash: block.Hash(), + Txs: make([]ethtypes.Transaction, len(block.Transactions())), + Timestamp: block.Time(), + } for i := range expected.Txs { expected.Txs[i] = *block.Transactions()[i] } - if err := v1.EncodeBlockForCompression(&block, buf); err != nil { + if err := v1.EncodeBlockForCompression(&block, &buf); err != nil { t.Fatalf("failed encoding the block: %s", err.Error()) } - var ( - encoded = buf.Bytes() - r = bytes.NewReader(encoded) - decoded, err = test_utils.DecodeBlockFromUncompressed(r) - size, errScan = v1.ScanBlockByteLen(encoded) - ) + encoded := buf.Bytes() + r := bytes.NewReader(encoded) + decoded, err := v1.DecodeBlockFromUncompressed(r) + size, errScan := v1.ScanBlockByteLen(encoded) assert.NoError(t, errScan, "error scanning the payload length") assert.NotZero(t, size, "scanned a block size of zero") @@ -138,7 +135,7 @@ func TestVectorDecode(t *testing.T) { var ( postPadded = append(b, postPad[:]...) r = bytes.NewReader(b) - _, errDec = test_utils.DecodeBlockFromUncompressed(r) + _, errDec = v1.DecodeBlockFromUncompressed(r) _, errScan = v1.ScanBlockByteLen(postPadded) ) diff --git a/prover/lib/compressor/blob/v1/test_utils/blob_maker_testing.go b/prover/lib/compressor/blob/v1/test_utils/blob_maker_testing.go index c437c8fc0..a281f1d5d 100644 --- a/prover/lib/compressor/blob/v1/test_utils/blob_maker_testing.go +++ b/prover/lib/compressor/blob/v1/test_utils/blob_maker_testing.go @@ -5,21 +5,17 @@ import ( "crypto/rand" "encoding/binary" "encoding/json" - "fmt" - "github.com/consensys/zkevm-monorepo/prover/backend/ethereum" - "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob" - v1 "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" - "os" - "path/filepath" - "strings" - "github.com/consensys/compress/lzss" fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/zkevm-monorepo/prover/backend/execution" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob" + "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/encode" + v1 "github.com/consensys/zkevm-monorepo/prover/lib/compressor/blob/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "os" + "path/filepath" + "strings" ) func GenTestBlob(t require.TestingT, maxNbBlocks int) []byte { @@ -78,7 +74,7 @@ func LoadTestBlocks(testDataDir string) (testBlocks [][]byte, err error) { return testBlocks, nil } -func RandIntn(n int) int { +func RandIntn(n int) int { // TODO @Tabaie remove var b [8]byte _, _ = rand.Read(b[:]) return int(binary.BigEndian.Uint64(b[:]) % uint64(n)) @@ -101,7 +97,7 @@ func EmptyBlob(t require.TestingT) []byte { assert.NoError(t, err) var bb bytes.Buffer - if _, err = v1.PackAlign(&bb, headerB.Bytes(), fr381.Bits-1, v1.WithAdditionalInput(compressor.Bytes())); err != nil { + if _, err = encode.PackAlign(&bb, headerB.Bytes(), fr381.Bits-1, encode.WithAdditionalInput(compressor.Bytes())); err != nil { panic(err) } return bb.Bytes() @@ -164,72 +160,6 @@ func TestBlocksAndBlobMaker(t require.TestingT) ([][]byte, *v1.BlobMaker) { return testBlocks, bm } -// DecodedBlockData is a wrapper struct storing the different fields of a block -// that we deserialize when decoding an ethereum block. -type DecodedBlockData struct { - // BlockHash stores the decoded block hash - BlockHash common.Hash - // Timestamp holds the Unix timestamp of the block in - Timestamp uint64 - // Froms stores the list of the sender address of every transaction - Froms []common.Address - // Txs stores the list of the decoded transactions. - Txs []types.Transaction -} - -// DecodeBlockFromUncompressed inverts [EncodeBlockForCompression]. It is primarily meant for -// testing and ensuring the encoding is bijective. -func DecodeBlockFromUncompressed(r *bytes.Reader) (DecodedBlockData, error) { - - var ( - decNumTxs uint16 - decTimestamp uint32 - blockHash common.Hash - ) - - if err := binary.Read(r, binary.BigEndian, &decNumTxs); err != nil { - return DecodedBlockData{}, fmt.Errorf("could not decode nb txs: %w", err) - } - - if err := binary.Read(r, binary.BigEndian, &decTimestamp); err != nil { - return DecodedBlockData{}, fmt.Errorf("could not decode timestamp: %w", err) - } - - if _, err := r.Read(blockHash[:]); err != nil { - return DecodedBlockData{}, fmt.Errorf("could not read the block hash: %w", err) - } - - var ( - numTxs = int(decNumTxs) - decodedBlk = DecodedBlockData{ - Froms: make([]common.Address, numTxs), - Txs: make([]types.Transaction, numTxs), - Timestamp: uint64(decTimestamp), - BlockHash: blockHash, - } - ) - - for i := 0; i < int(decNumTxs); i++ { - if err := DecodeTxFromUncompressed(r, &decodedBlk.Txs[i], &decodedBlk.Froms[i]); err != nil { - return DecodedBlockData{}, fmt.Errorf("could not decode transaction #%v: %w", i, err) - } - } - - return decodedBlk, nil -} - -func DecodeTxFromUncompressed(r *bytes.Reader, tx *types.Transaction, from *common.Address) (err error) { - if _, err := r.Read(from[:]); err != nil { - return fmt.Errorf("could not read from address: %w", err) - } - - if err := ethereum.DecodeTxFromBytes(r, tx); err != nil { - return fmt.Errorf("could not deserialize transaction") - } - - return nil -} - func GetDict(t require.TestingT) []byte { dict, err := blob.GetDict() require.NoError(t, err)