Skip to content

Commit

Permalink
Recover parked nodes from mem-cached layers with a small tree
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Aug 16, 2023
1 parent 6c996a2 commit 7eebb7d
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 49 deletions.
80 changes: 36 additions & 44 deletions prover/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package prover

import (
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
Expand Down Expand Up @@ -144,10 +146,15 @@ func makeRecoveryProofTree(
return nil, nil, fmt.Errorf("layer 0 cache file is missing")
}

var topLayer uint
parkedNodesMap := make(map[uint][]byte)

// Validate structure.
for layer, file := range layersFiles {
if layer > topLayer {
topLayer = layer
}

readWriter, err := layerFactory(uint(layer))
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -211,31 +218,16 @@ func makeRecoveryProofTree(
parkedNodes = append(parkedNodes, nil)
}
}

// recover parked nodes from mem-cached layers
fileLayers := len(parkedNodes)
memLayerFactory := NewReadWriterMetaFactory(
treeCfg.MinMemoryLayer,
treeCfg.Datadir,
treeCfg.FileWriterBufSize,
).GetFactory()
for layer := fileLayers; ; layer++ {
numNodes := nextLeafID >> layer
if numNodes == 0 {
break
}
if numNodes%2 == 0 {
parkedNodes = append(parkedNodes, nil)
} else {
node, err := GetNode(memLayerFactory, layer, numNodes-1, merkleHashFunc)
if err != nil {
return nil, nil, fmt.Errorf("recovering parked node in layer %d: %w", layer, err)
}
logging.FromContext(ctx).
Info("recovered (mem) parked node", zap.Uint("layer", uint(layer)), zap.String("node", fmt.Sprintf("%X", node)))
parkedNodes = append(parkedNodes, node)
}
layerReader, err := layerFactory(topLayer)
if err != nil {
return nil, nil, err
}
defer layerReader.Close()
memCachedParkedNodes, err := recoverMemCachedParkedNodes(layerReader, merkleHashFunc)
if err != nil {
return nil, nil, fmt.Errorf("Recoveing parked nodes from top layer of disk-cache: %w", err)

Check failure on line 228 in prover/prover.go

View workflow job for this annotation

GitHub Actions / quicktests

error strings should not be capitalized (ST1005)
}
parkedNodes = append(parkedNodes, memCachedParkedNodes...)

logging.FromContext(ctx).
Info("all recovered parked nodes", zap.Array("nodes", zapcore.ArrayMarshalerFunc(func(enc zapcore.ArrayEncoder) error {
Expand Down Expand Up @@ -403,33 +395,33 @@ func CalcTreeRoot(leaves [][]byte) ([]byte, error) {
return tree.Root(), nil
}

// GetNode returns the node at the given index in the given layer.
// It will recursively calculate the node if it is not already cached.
func GetNode(layerFactory mshared.LayerFactory, layer int, index uint64, hash merkle.HashFunc) ([]byte, error) {
readWriter, err := layerFactory(uint(layer))
// build a small tree with the nodes from the top layer of the cache as leafs.
// this tree will be used to get parked nodes for the merkle tree.
func recoverMemCachedParkedNodes(layerReader mshared.LayerReader, merkleHashFunc merkle.HashFunc) ([][]byte, error) {
tmpDir, err := os.MkdirTemp(os.TempDir(), "poet-recovery-tree")
if err != nil {
return nil, err
}
defer readWriter.Close()

// Try to obtain from cache.
if err := readWriter.Seek(index); err == nil {
return readWriter.ReadNext()
}
defer os.RemoveAll(tmpDir)

if layer == 0 {
return nil, fmt.Errorf("cannot recreate leaf at index %d", index)
}

left, err := GetNode(layerFactory, layer-1, index*2, hash)
tree, err := merkle.NewTreeBuilder().WithHashFunc(merkleHashFunc).Build()
if err != nil {
return nil, err
}

right, err := GetNode(layerFactory, layer-1, index*2+1, hash)
if err != nil {
return nil, err
// append nodes as leafs
for {
node, err := layerReader.ReadNext()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, fmt.Errorf("reading node from top layer of disk-cache: %w", err)
}
if err := tree.AddLeaf(node); err != nil {
return nil, fmt.Errorf("adding node to small tree: %w", err)
}
}

return hash(nil, left, right), nil
// the first parked node is for the leaves from the layerReader.
return tree.GetParkedNodes(nil)[1:], nil
}
3 changes: 1 addition & 2 deletions service/round.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func (r *round) persistExecution(
return r.saveState()
}

func (r *round) recoverExecution(ctx context.Context, end time.Time, minMemoryLayer, fileWriterBufSize uint) error {
func (r *round) recoverExecution(ctx context.Context, end time.Time, fileWriterBufSize uint) error {
logger := logging.FromContext(ctx).With(zap.String("round", r.ID))

started := time.Now()
Expand Down Expand Up @@ -263,7 +263,6 @@ func (r *round) recoverExecution(ctx context.Context, end time.Time, minMemoryLa
prover.TreeConfig{
Datadir: r.datadir,
FileWriterBufSize: fileWriterBufSize,
MinMemoryLayer: minMemoryLayer,
},
hash.GenLabelHashFunc(r.execution.Statement),
hash.GenMerkleHashFunc(r.execution.Statement),
Expand Down
4 changes: 2 additions & 2 deletions service/round_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ func TestRound_ExecutionRecovery(t *testing.T) {

ctx, stop := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer stop()
req.ErrorIs(round.recoverExecution(ctx, time.Now().Add(time.Hour), 2, 0), context.DeadlineExceeded)
req.ErrorIs(round.recoverExecution(ctx, time.Now().Add(time.Hour), 0), context.DeadlineExceeded)
req.NoError(round.teardown(context.Background(), false))
}

Expand All @@ -371,7 +371,7 @@ func TestRound_ExecutionRecovery(t *testing.T) {
req.Equal(len(challenges), numChallenges(round))
req.NoError(round.loadState())

req.NoError(round.recoverExecution(context.Background(), time.Now().Add(400*time.Millisecond), 2, 0))
req.NoError(round.recoverExecution(context.Background(), time.Now().Add(400*time.Millisecond), 0))
validateProof(t, round.execution)
req.NoError(round.teardown(context.Background(), true))
}
Expand Down
2 changes: 1 addition & 1 deletion service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func (s *Service) loop(ctx context.Context, roundToResume *round) error {
eg.Go(func() error {
unlock := lockOSThread(ctx, roundTidFile)
defer unlock()
err := round.recoverExecution(ctx, end, s.minMemoryLayer, s.cfg.TreeFileBufferSize)
err := round.recoverExecution(ctx, end, s.cfg.TreeFileBufferSize)
roundResults <- roundResult{round: round, err: err}
return nil
})
Expand Down

0 comments on commit 7eebb7d

Please sign in to comment.