diff --git a/service/round.go b/service/round.go index daabb713..d48bc710 100644 --- a/service/round.go +++ b/service/round.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io/fs" "os" "path/filepath" "strconv" @@ -102,6 +103,17 @@ func NewRound(datadir string, epoch uint, opts ...newRoundOptionFunc) (*round, e opt(r) } + err := r.loadState() + switch { + case errors.Is(err, fs.ErrNotExist): + // No state file, this is a new round. + if err := r.saveState(); err != nil { + return nil, err + } + case err != nil: + return nil, err + } + return r, nil } diff --git a/service/round_test.go b/service/round_test.go index 061e33db..f48d84c8 100644 --- a/service/round_test.go +++ b/service/round_test.go @@ -112,7 +112,6 @@ func TestRound_StateRecovery(t *testing.T) { recovered, err := NewRound(tmpdir, 0) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, recovered.Teardown(context.Background(), false)) }) - require.NoError(t, recovered.loadState()) // Verify require.False(t, recovered.IsFinished()) @@ -132,7 +131,6 @@ func TestRound_StateRecovery(t *testing.T) { recovered, err := NewRound(tmpdir, 0) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, recovered.Teardown(context.Background(), false)) }) - require.NoError(t, recovered.loadState()) // Verify require.False(t, recovered.IsFinished()) @@ -167,7 +165,6 @@ func TestRound_ExecutionRecovery(t *testing.T) { { round, err := NewRound(tmpdir, 1) req.NoError(err) - req.NoError(round.loadState()) ctx, stop := context.WithTimeout(context.Background(), 100*time.Millisecond) defer stop() @@ -179,7 +176,6 @@ func TestRound_ExecutionRecovery(t *testing.T) { { round, err := NewRound(tmpdir, 1) req.NoError(err) - req.NoError(round.loadState()) req.NoError(round.RecoverExecution(context.Background(), time.Now().Add(400*time.Millisecond), 0)) validateProof(t, round.execution) diff --git a/service/service.go b/service/service.go index f09b6a2b..c2c157c6 100644 --- a/service/service.go +++ b/service/service.go @@ -244,11 +244,6 @@ func (s *Service) recover(ctx context.Context) (executing *round, err error) { return nil, fmt.Errorf("failed to create round: %w", err) } - err = r.loadState() - if err != nil { - return nil, fmt.Errorf("invalid round state: %w", err) - } - logger.Info("recovered round", zap.Uint("epoch", r.epoch)) switch { @@ -261,9 +256,11 @@ func (s *Service) recover(ctx context.Context) (executing *round, err error) { ) s.onNewProof(ctx, r.epoch, r.execution) r.Teardown(ctx, true) - + case r.executionStarted.IsZero(): + // An open round from a previous poet version + logger.Info("round is open, removing it", zap.Uint("epoch", r.epoch)) + r.Teardown(ctx, true) default: - // Round is in executing state. logger.Info( "round is executing", zap.Time("started", r.executionStarted), diff --git a/service/service_test.go b/service/service_test.go index 6e820bd7..1f9bc578 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -9,11 +9,14 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "go.uber.org/zap/zaptest" "golang.org/x/sync/errgroup" "github.com/spacemeshos/poet/hash" + "github.com/spacemeshos/poet/logging" "github.com/spacemeshos/poet/server" "github.com/spacemeshos/poet/service" "github.com/spacemeshos/poet/service/mocks" @@ -222,7 +225,58 @@ func TestRecoverFinishedRound(t *testing.T) { req.Eventually(func() bool { _, err := os.Lstat(filepath.Join(datadir, "rounds", "9876")) return errors.Is(err, os.ErrNotExist) - }, time.Second, 10*time.Millisecond) + }, time.Second*5, 10*time.Millisecond) + + cancel() + req.NoError(eg.Wait()) +} + +func TestRemoveRecoveredOpenRound(t *testing.T) { + req := require.New(t) + ctx := logging.NewContext(context.Background(), zaptest.NewLogger(t)) + datadir := t.TempDir() + + // manually create a round and execute it + round, err := service.NewRound( + filepath.Join(datadir, "rounds"), + 1, + service.WithMembershipRoot([]byte{1, 2, 3, 4}), + ) + req.NoError(err) + t.Cleanup(func() { assert.NoError(t, round.Teardown(ctx, false)) }) + ctxE, cancel := context.WithTimeout(ctx, time.Millisecond*10) + defer cancel() + err = round.Execute(ctxE, time.Now().Add(time.Hour), 0, 0) + req.ErrorIs(err, context.DeadlineExceeded) + + // manually create an open round + openRound, err := service.NewRound( + filepath.Join(datadir, "rounds"), + 2, + service.WithMembershipRoot([]byte{1, 2, 3, 4}), + ) + req.NoError(err) + t.Cleanup(func() { assert.NoError(t, openRound.Teardown(ctx, true)) }) + + transport := transport.NewInMemory() + s, err := service.New( + ctx, + time.Now(), + datadir, + transport, + &server.RoundConfig{EpochDuration: time.Hour}, + ) + req.NoError(err) + + ctx, cancel = context.WithCancel(ctx) + defer cancel() + var eg errgroup.Group + eg.Go(func() error { return s.Run(ctx) }) + + req.Eventually(func() bool { + _, err := os.Lstat(filepath.Join(datadir, "rounds", "2")) + return errors.Is(err, os.ErrNotExist) + }, time.Second*5, 10*time.Millisecond) cancel() req.NoError(eg.Wait())