diff --git a/cmd/bench/bench.go b/cmd/bench/bench.go index 7b3a386a..6a5ecc5e 100644 --- a/cmd/bench/bench.go +++ b/cmd/bench/bench.go @@ -42,6 +42,8 @@ func main() { securityParam := shared.T tempdir, _ := os.MkdirTemp("", "poet-test") + defer os.RemoveAll(tempdir) + proofGenStarted := time.Now() end := proofGenStarted.Add(cfg.Duration) leafs, merkleProof, err := prover.GenerateProofWithoutPersistency( diff --git a/go.mod b/go.mod index f7e1d7b3..72bdd7de 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/zeebo/blake3 v0.2.3 go.uber.org/mock v0.2.0 go.uber.org/zap v1.25.0 + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 golang.org/x/sync v0.3.0 golang.org/x/sys v0.12.0 google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d diff --git a/go.sum b/go.sum index 94a9fbea..22eb47e3 100644 --- a/go.sum +++ b/go.sum @@ -258,6 +258,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= diff --git a/hashfunc_test.go b/hashfunc_test.go deleted file mode 100644 index 9927d60d..00000000 --- a/hashfunc_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package main - -import ( - "bytes" - "fmt" - "testing" - "time" - - "github.com/spacemeshos/sha256-simd" -) - -func BenchmarkSha256(t *testing.B) { - buff := bytes.Buffer{} - buff.Write([]byte("Seed data goes here")) - out := [32]byte{} - n := (uint64(1) << 20) * 101 - - fmt.Printf("Computing %d serial sha-256s...\n", n) - - t1 := time.Now() - - for i := uint64(0); i < n; i++ { - out = sha256.Sum256(buff.Bytes()) - buff.Reset() - buff.Write(out[:]) - } - - e := time.Since(t1) - r := uint64(float64(n) / e.Seconds()) - fmt.Printf("Final hash: %x. Running time: %s secs. Hash-rate: %d hashes-per-sec\n", buff.Bytes(), e, r) -} diff --git a/poet.go b/poet.go index 5b84ffbb..3d7f9003 100644 --- a/poet.go +++ b/poet.go @@ -44,11 +44,8 @@ func poetMain() (err error) { if err != nil { return err } + server.SetupConfig(cfg) - cfg, err = server.SetupConfig(cfg) - if err != nil { - return err - } // Finally, parse the remaining command line options again to ensure // they take precedence. cfg, err = server.ParseFlags(cfg) diff --git a/prover/layer_factory.go b/prover/layer_factory.go new file mode 100644 index 00000000..d7d97cf4 --- /dev/null +++ b/prover/layer_factory.go @@ -0,0 +1,27 @@ +package prover + +import ( + "fmt" + "path/filepath" + + "github.com/spacemeshos/merkle-tree/cache" + "github.com/spacemeshos/merkle-tree/cache/readwriters" +) + +// GetLayerFactory creates a merkle LayerFactory. +// The minMemoryLayer determines the threshold below which layers are saved on-disk, while layers equal and above - +// in-memory. +func GetLayerFactory(minMemoryLayer uint, datadir string, fileWriterBufSize uint) cache.LayerFactory { + return func(layerHeight uint) (cache.LayerReadWriter, error) { + if layerHeight < minMemoryLayer { + fileName := filepath.Join(datadir, fmt.Sprintf("layercache_%d.bin", layerHeight)) + readWriter, err := readwriters.NewFileReadWriter(fileName, int(fileWriterBufSize)) + if err != nil { + return nil, err + } + + return readWriter, nil + } + return &readwriters.SliceReadWriter{}, nil + } +} diff --git a/prover/prover.go b/prover/prover.go index 15964033..f1361a5d 100644 --- a/prover/prover.go +++ b/prover/prover.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "math" "os" "path/filepath" "regexp" @@ -43,17 +44,14 @@ type TreeConfig struct { type persistFunc func(ctx context.Context, treeCache *cache.Writer, nextLeafId uint64) error -var persist persistFunc = func(context.Context, *cache.Writer, uint64) error { return nil } - -// GenerateProof computes the PoET DAG, uses Fiat-Shamir to derive a challenge from the Merkle root and generates a -// Merkle proof using the challenge and the DAG. +// GenerateProof generates the Proof of Sequential Work. It stops when the given deadline is reached. func GenerateProof( ctx context.Context, leavesCounter prometheus.Counter, treeCfg TreeConfig, labelHashFunc func(data []byte) []byte, merkleHashFunc merkle.HashFunc, - limit time.Time, + deadline time.Time, securityParam uint8, persist persistFunc, ) (uint64, *shared.MerkleProof, error) { @@ -63,17 +61,17 @@ func GenerateProof( } defer treeCache.Close() - return generateProof(ctx, leavesCounter, labelHashFunc, tree, treeCache, limit, 0, securityParam, persist) + return generateProof(ctx, leavesCounter, labelHashFunc, tree, treeCache, deadline, 0, securityParam, persist) } -// GenerateProofRecovery recovers proof generation, from a given 'nextLeafID' and for a given 'parkedNodes' snapshot. +// GenerateProofRecovery recovers proof generation, from a given 'nextLeafID'. func GenerateProofRecovery( ctx context.Context, leavesCounter prometheus.Counter, treeCfg TreeConfig, labelHashFunc func(data []byte) []byte, merkleHashFunc merkle.HashFunc, - limit time.Time, + deadline time.Time, securityParam uint8, nextLeafID uint64, persist persistFunc, @@ -84,35 +82,50 @@ func GenerateProofRecovery( } defer treeCache.Close() - return generateProof(ctx, leavesCounter, labelHashFunc, tree, treeCache, limit, nextLeafID, securityParam, persist) + return generateProof( + ctx, + leavesCounter, + labelHashFunc, + tree, + treeCache, + deadline, + nextLeafID, + securityParam, + persist, + ) } -// GenerateProofWithoutPersistency calls GenerateProof with disabled persistency functionality -// and potential soft/hard-shutdown recovery. -// Meant to be used for testing purposes only. Doesn't expose metrics too. +// GenerateProofWithoutPersistency calls GenerateProof with disabled persistency functionality. +// Tree recovery will not be possible. Meant to be used for testing purposes only. +// It doesn't expose metrics too. func GenerateProofWithoutPersistency( ctx context.Context, treeCfg TreeConfig, labelHashFunc func(data []byte) []byte, merkleHashFunc merkle.HashFunc, - limit time.Time, + deadline time.Time, securityParam uint8, ) (uint64, *shared.MerkleProof, error) { leavesCounter := prometheus.NewCounter(prometheus.CounterOpts{}) - return GenerateProof(ctx, leavesCounter, treeCfg, labelHashFunc, merkleHashFunc, limit, securityParam, persist) + return GenerateProof( + ctx, + leavesCounter, + treeCfg, + labelHashFunc, + merkleHashFunc, + deadline, + securityParam, + func(context.Context, *cache.Writer, uint64) error { return nil }, + ) } func makeProofTree(treeCfg TreeConfig, merkleHashFunc merkle.HashFunc) (*merkle.Tree, *cache.Writer, error) { - if treeCfg.MinMemoryLayer < LowestMerkleMinMemoryLayer { - treeCfg.MinMemoryLayer = LowestMerkleMinMemoryLayer - } - metaFactory := NewReadWriterMetaFactory(treeCfg.MinMemoryLayer, treeCfg.Datadir, treeCfg.FileWriterBufSize) - + minMemoryLayer := max(treeCfg.MinMemoryLayer, LowestMerkleMinMemoryLayer) treeCache := cache.NewWriter( cache.Combine( cache.SpecificLayersPolicy(map[uint]bool{0: true}), cache.MinHeightPolicy(MerkleMinCacheLayer)), - metaFactory.GetFactory(), + GetLayerFactory(minMemoryLayer, treeCfg.Datadir, treeCfg.FileWriterBufSize), ) tree, err := merkle.NewTreeBuilder().WithHashFunc(merkleHashFunc).WithCacheWriter(treeCache).Build() @@ -130,8 +143,7 @@ func makeRecoveryProofTree( nextLeafID uint64, ) (*cache.Writer, *merkle.Tree, error) { // Don't use memory cache. Just utilize the existing files cache. - maxUint := ^uint(0) - layerFactory := NewReadWriterMetaFactory(maxUint, treeCfg.Datadir, treeCfg.FileWriterBufSize).GetFactory() + layerFactory := GetLayerFactory(math.MaxUint, treeCfg.Datadir, treeCfg.FileWriterBufSize) layersFiles, err := getLayersFiles(treeCfg.Datadir) if err != nil { @@ -139,8 +151,7 @@ func makeRecoveryProofTree( } // Validate that layer 0 exists. - _, ok := layersFiles[0] - if !ok { + if _, ok := layersFiles[0]; !ok { return nil, nil, fmt.Errorf("layer 0 cache file is missing") } @@ -149,11 +160,9 @@ func makeRecoveryProofTree( // Validate structure. for layer, file := range layersFiles { - if layer > topLayer { - topLayer = layer - } + topLayer = max(topLayer, layer) - readWriter, err := layerFactory(uint(layer)) + readWriter, err := layerFactory(layer) if err != nil { return nil, nil, err } @@ -225,7 +234,7 @@ func makeRecoveryProofTree( parkedNodes = append(parkedNodes, memCachedParkedNodes...) logging.FromContext(ctx). - Info("recovered parked nodes", zap.Array("nodes", zapcore.ArrayMarshalerFunc(func(enc zapcore.ArrayEncoder) error { + Debug("recovered parked nodes", zap.Array("nodes", zapcore.ArrayMarshalerFunc(func(enc zapcore.ArrayEncoder) error { for _, node := range parkedNodes { enc.AppendString(fmt.Sprintf("%X", node)) } @@ -384,29 +393,13 @@ func getLayersFiles(datadir string) (map[uint]string, error) { return files, nil } -// Calculate the root of a Merkle Tree with given leaves. -func CalcTreeRoot(leaves [][]byte) ([]byte, error) { - tree, err := merkle.NewTreeBuilder().WithHashFunc(shared.HashMembershipTreeNode).Build() - if err != nil { - return nil, fmt.Errorf("failed to generate tree: %w", err) - } - for _, member := range leaves { - err := tree.AddLeaf(member) - if err != nil { - return nil, fmt.Errorf("failed to add leaf: %w", err) - } - } - return tree.Root(), nil -} - // build a small tree with the nodes from the top layer of the cache as leafs. // this tree will be used to get parked nodes for the merkle tree. func recoverMemCachedParkedNodes( layerReader mshared.LayerReader, merkleHashFunc merkle.HashFunc, ) ([][]byte, mshared.CacheReader, error) { - recoveryTreelayerFactory := NewReadWriterMetaFactory(0, "", 0).GetFactory() - recoveryTreeCache := cache.NewWriter(func(uint) bool { return true }, recoveryTreelayerFactory) + recoveryTreeCache := cache.NewWriter(func(uint) bool { return true }, GetLayerFactory(0, "", 0)) tree, err := merkle.NewTreeBuilder().WithHashFunc(merkleHashFunc).WithCacheWriter(recoveryTreeCache).Build() if err != nil { diff --git a/prover/readwritermetafactory.go b/prover/readwritermetafactory.go deleted file mode 100644 index 3abb2b62..00000000 --- a/prover/readwritermetafactory.go +++ /dev/null @@ -1,73 +0,0 @@ -package prover - -import ( - "errors" - "fmt" - "os" - "path/filepath" - - "github.com/spacemeshos/merkle-tree/cache" - "github.com/spacemeshos/merkle-tree/cache/readwriters" -) - -// ReadWriterMetaFactory generates Merkle LayerFactory functions. The functions it creates generate file read-writers -// starting from the base layer and up to minMemoryLayer-1. From minMemoryLayer and up the functions generate slice -// read-writers. -// The MetaFactory tracks the files it creates and removes them when Cleanup() is called. -type ReadWriterMetaFactory struct { - minMemoryLayer uint - datadir string - filesCreated map[string]bool - fileWriterBufSize uint -} - -// NewReadWriterMetaFactory returns a new ReadWriterMetaFactory. minMemoryLayer determines -// the threshold in which lower layers caching is done on-disk, while from this layer up -- in-memory. -func NewReadWriterMetaFactory(minMemoryLayer uint, datadir string, fileWriterBufSize uint) *ReadWriterMetaFactory { - return &ReadWriterMetaFactory{ - minMemoryLayer: minMemoryLayer, - datadir: datadir, - filesCreated: make(map[string]bool), - fileWriterBufSize: fileWriterBufSize, - } -} - -// GetFactory creates a Merkle LayerFactory function. -func (mf *ReadWriterMetaFactory) GetFactory() cache.LayerFactory { - return func(layerHeight uint) (cache.LayerReadWriter, error) { - if layerHeight < mf.minMemoryLayer { - fileName, err := mf.makeFileName(layerHeight) - if err != nil { - return nil, err - } - - readWriter, err := readwriters.NewFileReadWriter(fileName, int(mf.fileWriterBufSize)) - if err != nil { - return nil, err - } - - mf.filesCreated[fileName] = true - return readWriter, nil - } - return &readwriters.SliceReadWriter{}, nil - } -} - -// Cleanup removes the files that were created by the LayerFactory functions generated by this MetaFactory. -func (mf *ReadWriterMetaFactory) Cleanup() error { - var result error - failedRemovals := make(map[string]bool) - for filename := range mf.filesCreated { - err := os.Remove(filename) - if err != nil { - result = errors.Join(err, fmt.Errorf("could not remove temp file %v: %w", filename, err)) - failedRemovals[filename] = true - } - } - mf.filesCreated = failedRemovals - return result -} - -func (mf *ReadWriterMetaFactory) makeFileName(layer uint) (string, error) { - return filepath.Join(mf.datadir, fmt.Sprintf("layercache_%d.bin", layer)), nil -} diff --git a/rpc/rpcserver_test.go b/rpc/rpcserver_test.go index 0c940b72..7c286e79 100644 --- a/rpc/rpcserver_test.go +++ b/rpc/rpcserver_test.go @@ -11,13 +11,11 @@ import ( api "github.com/spacemeshos/poet/release/proto/go/rpc/api/v1" "github.com/spacemeshos/poet/rpc" - "github.com/spacemeshos/poet/server" ) func Test_Submit_DoesNotPanicOnMissingPubKey(t *testing.T) { // Arrange - cfg := server.DefaultConfig() - sv := rpc.NewServer(nil, nil, cfg.Round.PhaseShift, cfg.Round.CycleGap) + sv := rpc.NewServer(nil, nil, 0, 0) // Act in := &api.SubmitRequest{} @@ -35,8 +33,7 @@ func Test_Submit_DoesNotPanicOnMissingPubKey(t *testing.T) { func Test_Submit_DoesNotPanicOnMissingSignature(t *testing.T) { // Arrange - cfg := server.DefaultConfig() - sv := rpc.NewServer(nil, nil, cfg.Round.PhaseShift, cfg.Round.CycleGap) + sv := rpc.NewServer(nil, nil, 0, 0) pub, _, err := ed25519.GenerateKey(nil) require.NoError(t, err) diff --git a/server/config.go b/server/config.go index 9d4f2442..0f05ccea 100644 --- a/server/config.go +++ b/server/config.go @@ -1,16 +1,10 @@ -// Copyright (c) 2013-2017 The btcsuite developers -// Copyright (c) 2015-2016 The Decred developers -// Copyright (c) 2017-2023 The Spacemesh developers - package server import ( "context" "fmt" "os" - "os/user" "path/filepath" - "strings" "time" "github.com/jessevdk/go-flags" @@ -35,10 +29,6 @@ const ( defaultCycleGap = 10 * time.Second ) -// Config defines the configuration options for poet. -// -// See loadConfig for further details regarding the -// configuration loading+parsing process. type Config struct { Genesis Genesis `long:"genesis-time" description:"Genesis timestamp in RFC3339 format"` PoetDir string `long:"poetdir" description:"The base directory that contains poet's data, logs, configuration file, etc."` @@ -125,8 +115,8 @@ func ReadConfigFile(cfg *Config) (*Config, error) { return cfg, nil } -// SetupConfig expands paths and initializes filesystem. -func SetupConfig(cfg *Config) (*Config, error) { +// SetupConfig adjusts the paths in the config to be relative to the poetdir. +func SetupConfig(cfg *Config) { // If the provided poet directory is not the default, we'll modify the // path to all of the files and directories that will live within it. defaultCfg := DefaultConfig() @@ -141,45 +131,6 @@ func SetupConfig(cfg *Config) (*Config, error) { cfg.DbDir = filepath.Join(cfg.PoetDir, defaultDbDirName) } } - - // Create the poet directory if it doesn't already exist. - if err := os.MkdirAll(cfg.PoetDir, 0o700); err != nil { - return nil, fmt.Errorf("failed to create %v: %w", cfg.PoetDir, err) - } - - // As soon as we're done parsing configuration options, ensure all paths - // to directories and files are cleaned and expanded before attempting - // to use them later on. - cfg.DataDir = cleanAndExpandPath(cfg.DataDir) - cfg.LogDir = cleanAndExpandPath(cfg.LogDir) - - return cfg, nil -} - -// cleanAndExpandPath expands environment variables and leading ~ in the -// passed path, cleans the result, and returns it. -// This function is taken from https://github.com/btcsuite/btcd -func cleanAndExpandPath(path string) string { - if path == "" { - return "" - } - - // Expand initial ~ to OS specific home directory. - if strings.HasPrefix(path, "~") { - var homeDir string - user, err := user.Current() - if err == nil { - homeDir = user.HomeDir - } else { - homeDir = os.Getenv("HOME") - } - - path = strings.Replace(path, "~", homeDir, 1) - } - - // NOTE: The os.ExpandEnv doesn't work with Windows-style %VARIABLE%, - // but the variables can still be expanded via POSIX-style $VARIABLE. - return filepath.Clean(os.ExpandEnv(path)) } type RoundConfig struct { diff --git a/server/server.go b/server/server.go index febb498f..c52bd56f 100644 --- a/server/server.go +++ b/server/server.go @@ -66,7 +66,6 @@ func New(ctx context.Context, cfg Config) (*Server, error) { if err != nil { return nil, err } - restListener, err := net.Listen(addr.Network(), addr.String()) if err != nil { return nil, fmt.Errorf("failed to listen: %v", err) @@ -157,17 +156,22 @@ func (s *Server) Start(ctx context.Context) error { logger := logging.FromContext(ctx) - // grpc metrics metrics := grpc_prometheus.NewServerMetrics( grpc_prometheus.WithServerHandlingTimeHistogram( grpc_prometheus.WithHistogramBuckets(prometheus.ExponentialBuckets(0.001, 2, 16)), ), ) - // Initialize and register the implementation of gRPC interface - var grpcServer *grpc.Server - var proxyRegstr []func(context.Context, *proxy.ServeMux, string, []grpc.DialOption) error - options := []grpc.ServerOption{ + serverGroup.Go(func() error { + return s.reg.Run(ctx) + }) + + serverGroup.Go(func() error { + return s.svc.Run(ctx) + }) + + rpcServer := rpc.NewServer(s.svc, s.reg, s.cfg.Round.PhaseShift, s.cfg.Round.CycleGap) + grpcServer := grpc.NewServer( grpc.UnaryInterceptor(grpcmw.ChainUnaryServer( loggerInterceptor(logger), metrics.UnaryServerInterceptor(), @@ -182,22 +186,9 @@ func (s *Server) Start(ctx context.Context) error { Time: time.Minute, Timeout: time.Minute * 3, }), - } - - serverGroup.Go(func() error { - return s.reg.Run(ctx) - }) - - serverGroup.Go(func() error { - return s.svc.Run(ctx) - }) - - rpcServer := rpc.NewServer(s.svc, s.reg, s.cfg.Round.PhaseShift, s.cfg.Round.CycleGap) - grpcServer = grpc.NewServer(options...) + ) api.RegisterPoetServiceServer(grpcServer, rpcServer) - proxyRegstr = append(proxyRegstr, api.RegisterPoetServiceHandlerFromEndpoint) - reflection.Register(grpcServer) metrics.InitializeMetrics(grpcServer) prometheus.Register(metrics) @@ -210,16 +201,16 @@ func (s *Server) Start(ctx context.Context) error { // Start the REST proxy for the gRPC server above. mux := proxy.NewServeMux() - for _, r := range proxyRegstr { - err := r( - ctx, - mux, - s.rpcListener.Addr().String(), - []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, - ) - if err != nil { - return err - } + err := api.RegisterPoetServiceHandlerFromEndpoint( + ctx, + mux, + s.rpcListener.Addr().String(), + []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + }, + ) + if err != nil { + return err } server := &http.Server{Handler: mux, ReadHeaderTimeout: time.Second * 5} @@ -234,8 +225,10 @@ func (s *Server) Start(ctx context.Context) error { // Wait for the server to shut down gracefully <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + err = server.Shutdown(shutdownCtx) grpcServer.GracefulStop() - err := server.Shutdown(context.Background()) return errors.Join(err, serverGroup.Wait()) } diff --git a/server/server_test.go b/server/server_test.go index e9fb4aa6..bea6cebe 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -21,7 +21,6 @@ import ( "github.com/spacemeshos/poet/hash" "github.com/spacemeshos/poet/logging" - "github.com/spacemeshos/poet/prover" "github.com/spacemeshos/poet/registration" api "github.com/spacemeshos/poet/release/proto/go/rpc/api/v1" "github.com/spacemeshos/poet/server" @@ -35,8 +34,7 @@ func spawnPoet(ctx context.Context, t *testing.T, cfg server.Config) (*server.Se t.Helper() req := require.New(t) - _, err := server.SetupConfig(&cfg) - req.NoError(err) + server.SetupConfig(&cfg) srv, err := server.New(ctx, cfg) req.NoError(err) @@ -263,8 +261,8 @@ func TestSubmitAndGetProof(t *testing.T) { ProofNodes: proof.Proof.Proof.ProofNodes, } - root, err := prover.CalcTreeRoot(proof.Proof.Members) - req.NoError(err) + // single-element tree: the leaf is the root + root := proof.Proof.Members[0] labelHashFunc := hash.GenLabelHashFunc(root) merkleHashFunc := hash.GenMerkleHashFunc(root) @@ -471,8 +469,7 @@ func TestLoadSubmits(t *testing.T) { cfg.Round.PhaseShift = time.Minute cfg.RawRPCListener = randomHost cfg.RawRESTListener = randomHost - cfg, err := server.SetupConfig(cfg) - req.NoError(err) + server.SetupConfig(cfg) srv, err := server.New(context.Background(), *cfg) t.Cleanup(func() { assert.NoError(t, srv.Close()) }) diff --git a/service/service.go b/service/service.go index c2c157c6..e65bb39a 100644 --- a/service/service.go +++ b/service/service.go @@ -82,13 +82,6 @@ func New( minMemoryLayer = totalLayers - options.cfg.MemoryLayers } - roundsDir := filepath.Join(datadir, "rounds") - if _, err := os.Stat(roundsDir); errors.Is(err, os.ErrNotExist) { - if err := os.Mkdir(roundsDir, 0o700); err != nil { - return nil, err - } - } - s := &Service{ genesis: genesis, cfg: options.cfg, @@ -222,13 +215,17 @@ func (s *Service) Run(ctx context.Context) error { func (s *Service) recover(ctx context.Context) (executing *round, err error) { roundsDir := filepath.Join(s.datadir, "rounds") - logger := logging.FromContext(ctx).Named("recovery") - logger.Info("recovering worker state", zap.String("datadir", s.datadir)) entries, err := os.ReadDir(roundsDir) - if err != nil { + switch { + case errors.Is(err, os.ErrNotExist): + return nil, nil + case err != nil: return nil, err } + logger := logging.FromContext(ctx).Named("recovery") + logger.Info("recovering worker state", zap.String("datadir", s.datadir)) + for _, entry := range entries { logger.Sugar().Infof("recovering entry %s", entry.Name()) if !entry.IsDir() { diff --git a/verifier/verifier.go b/verifier/verifier.go index 087c3537..2bc3eca6 100644 --- a/verifier/verifier.go +++ b/verifier/verifier.go @@ -2,10 +2,12 @@ package verifier import ( "bytes" + "errors" "fmt" "slices" "github.com/spacemeshos/merkle-tree" + "golang.org/x/exp/maps" "github.com/spacemeshos/poet/shared" ) @@ -22,50 +24,37 @@ func Validate(proof shared.MerkleProof, labelHashFunc func(data []byte) []byte, ) } - provenLeafIndices := asSortedSlice(shared.FiatShamir(proof.Root, numLeaves, securityParam)) - provenLeaves := make([][]byte, 0, len(proof.ProvenLeaves)) - for i := range proof.ProvenLeaves { - provenLeaves = append(provenLeaves, proof.ProvenLeaves[i][:]) - } - proofNodes := make([][]byte, 0, len(proof.ProofNodes)) - for i := range proof.ProofNodes { - proofNodes = append(proofNodes, proof.ProofNodes[i][:]) - } + provenLeafIndices := maps.Keys((shared.FiatShamir(proof.Root, numLeaves, securityParam))) + slices.Sort(provenLeafIndices) + valid, parkingSnapshots, err := merkle.ValidatePartialTreeWithParkingSnapshots( provenLeafIndices, - provenLeaves, - proofNodes, + proof.ProvenLeaves, + proof.ProofNodes, proof.Root, merkleHashFunc, ) if err != nil { - return fmt.Errorf("error while validating merkle proof: %v", err) + return fmt.Errorf("error while validating merkle proof: %w", err) } if !valid { - return fmt.Errorf("merkle proof not valid") + return errors.New("merkle proof not valid") } if len(parkingSnapshots) != len(proof.ProvenLeaves) { - return fmt.Errorf("merkle proof not valid") + return fmt.Errorf( + "merkle proof not valid: len(parkingSnapshots) != len(proof.ProvenLeaves) (%d != %d)", + len(parkingSnapshots), + len(proof.ProvenLeaves), + ) } makeLabel := shared.MakeLabelFunc() for id, label := range proof.ProvenLeaves { expectedLabel := makeLabel(labelHashFunc, provenLeafIndices[id], parkingSnapshots[id]) - if !bytes.Equal(expectedLabel, label[:]) { + if !bytes.Equal(expectedLabel, label) { return fmt.Errorf("label at index %d incorrect - expected: %x actual: %x", id, expectedLabel, label) } } return nil } - -func asSortedSlice(s map[uint64]bool) []uint64 { - ret := make([]uint64, 0, len(s)) - for key, value := range s { - if value { - ret = append(ret, key) - } - } - slices.Sort(ret) - return ret -} diff --git a/verifier/verifier_test.go b/verifier/verifier_test.go index df8e69c9..03e52ece 100644 --- a/verifier/verifier_test.go +++ b/verifier/verifier_test.go @@ -133,6 +133,5 @@ func TestValidateFailLabelValidation(t *testing.T) { ) r.NoError(err) err = Validate(*merkleProof, BadLabelHashFunc, hash.GenMerkleHashFunc(challenge), leafs, securityParam) - r.Error(err) - r.Regexp("label at index 0 incorrect - expected: [0-f]* actual: [0-f]*", err.Error()) + r.ErrorContains(err, "label at index 0 incorrect") }