From 8b789aed346647484e36a1edf8ce42d0f20f1e03 Mon Sep 17 00:00:00 2001 From: Juho Makinen Date: Fri, 10 Jan 2025 14:13:59 +1100 Subject: [PATCH] chore: window function to be a proper window function, stateIter to return always the current state as the first element --- backend/controller/controller.go | 2 +- backend/controller/state/eventextractor.go | 11 ++-- .../controller/state/eventextractor_test.go | 3 +- internal/channels/itercontext_test.go | 45 ++++++++++++++++ internal/iterops/changes.go | 6 +-- internal/iterops/concat.go | 15 ++++++ internal/iterops/const.go | 13 +++++ internal/iterops/dedup.go | 20 +++++++ internal/iterops/empty.go | 8 +++ internal/iterops/interops_test.go | 52 +++++++++++++++++-- internal/iterops/map.go | 4 +- internal/iterops/window.go | 15 +++--- internal/raft/cluster.go | 13 ++++- internal/raft/cluster_test.go | 8 +-- internal/statemachine/handle.go | 33 +++++++----- internal/statemachine/handle_test.go | 32 ++++++++---- 16 files changed, 233 insertions(+), 47 deletions(-) create mode 100644 internal/channels/itercontext_test.go create mode 100644 internal/iterops/concat.go create mode 100644 internal/iterops/const.go create mode 100644 internal/iterops/dedup.go create mode 100644 internal/iterops/empty.go diff --git a/backend/controller/controller.go b/backend/controller/controller.go index e6687afc90..06aa02a781 100644 --- a/backend/controller/controller.go +++ b/backend/controller/controller.go @@ -1063,7 +1063,7 @@ func (s *Service) watchModuleChanges(ctx context.Context, sendChange func(respon } logger.Tracef("Seeded %d deployments", initialCount) - for notification := range iterops.Changes(stateIter, view, state.EventExtractor) { + for notification := range iterops.Changes(stateIter, state.EventExtractor) { switch event := notification.(type) { case *state.DeploymentCreatedEvent: err := sendChange(&ftlv1.PullSchemaResponse{ //nolint:forcetypeassert diff --git a/backend/controller/state/eventextractor.go b/backend/controller/state/eventextractor.go index 7280019de5..c77d536389 100644 --- a/backend/controller/state/eventextractor.go +++ b/backend/controller/state/eventextractor.go @@ -1,9 +1,14 @@ package state -import "github.com/alecthomas/types/tuple" +import ( + "iter" + + "github.com/alecthomas/types/tuple" + "github.com/block/ftl/internal/iterops" +) // EventExtractor calculates controller events from changes to the state. -func EventExtractor(diff tuple.Pair[SchemaState, SchemaState]) []SchemaEvent { +func EventExtractor(diff tuple.Pair[SchemaState, SchemaState]) iter.Seq[SchemaEvent] { var events []SchemaEvent previous := diff.A @@ -43,5 +48,5 @@ func EventExtractor(diff tuple.Pair[SchemaState, SchemaState]) []SchemaEvent { }) } } - return events + return iterops.Const(events...) } diff --git a/backend/controller/state/eventextractor_test.go b/backend/controller/state/eventextractor_test.go index 1417d74e9f..796bd0a6be 100644 --- a/backend/controller/state/eventextractor_test.go +++ b/backend/controller/state/eventextractor_test.go @@ -1,6 +1,7 @@ package state import ( + "slices" "testing" "time" @@ -127,7 +128,7 @@ func TestEventExtractor(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := EventExtractor(tuple.PairOf(tt.previous, tt.current)) + got := slices.Collect(EventExtractor(tuple.PairOf(tt.previous, tt.current))) assert.Equal(t, tt.want, got) }) } diff --git a/internal/channels/itercontext_test.go b/internal/channels/itercontext_test.go new file mode 100644 index 0000000000..dbe0e3569f --- /dev/null +++ b/internal/channels/itercontext_test.go @@ -0,0 +1,45 @@ +package channels + +import ( + "context" + "slices" + "testing" + "time" + + "github.com/alecthomas/assert/v2" +) + +func TestIterContext(t *testing.T) { + t.Run("iterates until channel closed", func(t *testing.T) { + ch := make(chan int) + ctx := context.Background() + + // Start goroutine to send values + go func() { + ch <- 1 + ch <- 2 + ch <- 3 + close(ch) + }() + + assert.Equal(t, []int{1, 2, 3}, slices.Collect(IterContext(ctx, ch))) + }) + + t.Run("stops when context cancelled", func(t *testing.T) { + ch := make(chan int) + ctx, cancel := context.WithCancel(context.Background()) + + // Start goroutine to send values + go func() { + ch <- 1 + ch <- 2 + time.Sleep(10 * time.Millisecond) // Small delay to ensure cancel happens + cancel() // Cancel context before sending 3 + ch <- 3 // This should not be received + close(ch) + }() + + assert.Equal(t, []int{1, 2}, slices.Collect(IterContext(ctx, ch))) + assert.Error(t, ctx.Err()) + }) +} diff --git a/internal/iterops/changes.go b/internal/iterops/changes.go index a8a1b41226..c3f05de227 100644 --- a/internal/iterops/changes.go +++ b/internal/iterops/changes.go @@ -7,9 +7,9 @@ import ( ) // ChangeExtractor extracts changes from an old and new state. -type ChangeExtractor[S, C any] func(tuple.Pair[S, S]) []C +type ChangeExtractor[S, C any] func(tuple.Pair[S, S]) iter.Seq[C] // Changes returns a stream of change events from a stream of evolving state. -func Changes[S, C any](in iter.Seq[S], start S, extractor ChangeExtractor[S, C]) iter.Seq[C] { - return FlatMap(WindowPair(in, start), extractor) +func Changes[S, C any](in iter.Seq[S], extractor ChangeExtractor[S, C]) iter.Seq[C] { + return FlatMap(WindowPair(in), extractor) } diff --git a/internal/iterops/concat.go b/internal/iterops/concat.go new file mode 100644 index 0000000000..f6141695c0 --- /dev/null +++ b/internal/iterops/concat.go @@ -0,0 +1,15 @@ +package iterops + +import "iter" + +func Concat[T any](in ...iter.Seq[T]) iter.Seq[T] { + return func(yield func(T) bool) { + for _, n := range in { + for m := range n { + if !yield(m) { + return + } + } + } + } +} diff --git a/internal/iterops/const.go b/internal/iterops/const.go new file mode 100644 index 0000000000..9d072c30c6 --- /dev/null +++ b/internal/iterops/const.go @@ -0,0 +1,13 @@ +package iterops + +import "iter" + +func Const[T any](in ...T) iter.Seq[T] { + return func(yield func(T) bool) { + for _, n := range in { + if !yield(n) { + return + } + } + } +} diff --git a/internal/iterops/dedup.go b/internal/iterops/dedup.go new file mode 100644 index 0000000000..611e63244e --- /dev/null +++ b/internal/iterops/dedup.go @@ -0,0 +1,20 @@ +package iterops + +import ( + "iter" + "reflect" +) + +// Dedup returns an iterator that yields values from the input iterator, removing consecutive duplicates. +func Dedup[T any](seq iter.Seq[T]) iter.Seq[T] { + return func(yield func(T) bool) { + var last T + seq(func(v T) bool { + if reflect.DeepEqual(v, last) { + return true + } + last = v + return yield(v) + }) + } +} diff --git a/internal/iterops/empty.go b/internal/iterops/empty.go new file mode 100644 index 0000000000..fc2f132fd0 --- /dev/null +++ b/internal/iterops/empty.go @@ -0,0 +1,8 @@ +package iterops + +import "iter" + +// Empty returns an empty iterator. +func Empty[T any]() iter.Seq[T] { + return func(yield func(T) bool) {} +} diff --git a/internal/iterops/interops_test.go b/internal/iterops/interops_test.go index 1829f0becc..b11e0c5de6 100644 --- a/internal/iterops/interops_test.go +++ b/internal/iterops/interops_test.go @@ -1,6 +1,7 @@ package iterops_test import ( + "iter" "slices" "testing" @@ -11,9 +12,8 @@ import ( func TestWindowPair(t *testing.T) { input := slices.Values([]int{1, 2, 3, 4}) - result := slices.Collect(iterops.WindowPair(input, 0)) + result := slices.Collect(iterops.WindowPair(input)) assert.Equal(t, result, []tuple.Pair[int, int]{ - tuple.PairOf(0, 1), tuple.PairOf(1, 2), tuple.PairOf(2, 3), tuple.PairOf(3, 4), @@ -28,6 +28,52 @@ func TestMap(t *testing.T) { func TestFlatMap(t *testing.T) { input := slices.Values([]int{1, 2, 3, 4}) - result := slices.Collect(iterops.FlatMap(input, func(v int) []int { return []int{v, v * 2} })) + result := slices.Collect(iterops.FlatMap(input, func(v int) iter.Seq[int] { return iterops.Const(v, v*2) })) assert.Equal(t, result, []int{1, 2, 2, 4, 3, 6, 4, 8}) } + +func TestConcat(t *testing.T) { + input := slices.Values([]int{1, 2, 3, 4}) + result := slices.Collect(iterops.Concat(input, input)) + assert.Equal(t, result, []int{1, 2, 3, 4, 1, 2, 3, 4}) + + result = slices.Collect(iterops.Concat( + iterops.Const(1), + iterops.Const(2), + iterops.Const(3), + iterops.Const(4), + )) + assert.Equal(t, result, []int{1, 2, 3, 4}) +} + +func TestConst(t *testing.T) { + input := 1 + result := slices.Collect(iterops.Const(input)) + assert.Equal(t, result, []int{1}) +} + +func TestEmpty(t *testing.T) { + result := slices.Collect(iterops.Empty[int]()) + assert.Equal(t, result, nil) + + assert.Equal(t, slices.Collect(iterops.Concat( + iterops.Empty[int](), + iterops.Empty[int](), + )), nil) + + assert.Equal(t, slices.Collect(iterops.Concat( + iterops.Const(1), + iterops.Empty[int](), + )), []int{1}) + + assert.Equal(t, slices.Collect(iterops.Concat( + iterops.Empty[int](), + iterops.Const(1), + )), []int{1}) +} + +func TestDedup(t *testing.T) { + input := slices.Values([]int{1, 2, 2, 3, 3, 4, 1}) + result := slices.Collect(iterops.Dedup(input)) + assert.Equal(t, result, []int{1, 2, 3, 4, 1}) +} diff --git a/internal/iterops/map.go b/internal/iterops/map.go index 26c159c362..eb74ae6e84 100644 --- a/internal/iterops/map.go +++ b/internal/iterops/map.go @@ -12,10 +12,10 @@ func Map[T any, U any](in iter.Seq[T], fn func(T) U) iter.Seq[U] { } } -func FlatMap[T any, U any](in iter.Seq[T], fn func(T) []U) iter.Seq[U] { +func FlatMap[T any, U any](in iter.Seq[T], fn func(T) iter.Seq[U]) iter.Seq[U] { return func(yield func(U) bool) { for n := range in { - for _, u := range fn(n) { + for u := range fn(n) { if !yield(u) { return } diff --git a/internal/iterops/window.go b/internal/iterops/window.go index f5c6e05248..7b355c3408 100644 --- a/internal/iterops/window.go +++ b/internal/iterops/window.go @@ -3,19 +3,22 @@ package iterops import ( "iter" + "github.com/alecthomas/types/optional" "github.com/alecthomas/types/tuple" ) // WindowPair returns a window of size 2 of the input iterator. -func WindowPair[T any](in iter.Seq[T], start T) iter.Seq[tuple.Pair[T, T]] { +func WindowPair[T any](in iter.Seq[T]) iter.Seq[tuple.Pair[T, T]] { return func(yield func(tuple.Pair[T, T]) bool) { - previous := start + previous := optional.None[T]() for n := range in { - result := tuple.PairOf(previous, n) - previous = n - if !yield(result) { - return + if val, ok := previous.Get(); ok { + result := tuple.PairOf(val, n) + if !yield(result) { + return + } } + previous = optional.Some(n) } } } diff --git a/internal/raft/cluster.go b/internal/raft/cluster.go index 3b70d6c287..efeb4217f3 100644 --- a/internal/raft/cluster.go +++ b/internal/raft/cluster.go @@ -22,6 +22,7 @@ import ( raftpbconnect "github.com/block/ftl/backend/protos/xyz/block/ftl/raft/v1/raftpbconnect" ftlv1 "github.com/block/ftl/backend/protos/xyz/block/ftl/v1" "github.com/block/ftl/internal/channels" + "github.com/block/ftl/internal/iterops" "github.com/block/ftl/internal/log" "github.com/block/ftl/internal/retry" "github.com/block/ftl/internal/rpc" @@ -208,9 +209,16 @@ func (s *ShardHandle[Q, R, E]) StateIter(ctx context.Context, query Q) (iter.Seq panic("cluster not started") } - result := make(chan R) + result := make(chan R, 64) logger := log.FromContext(ctx).Scope("raft") + previous, err := s.Query(ctx, query) + if err != nil { + return nil, err + } + + result <- previous + // get the last known index as the starting point last, err := s.getLastIndex() if err != nil { @@ -253,7 +261,8 @@ func (s *ShardHandle[Q, R, E]) StateIter(ctx context.Context, query Q) (iter.Seq } }() - return channels.IterContext(ctx, result), nil + // dedup, as we might get false positives due to index changes caused by membership changes + return iterops.Dedup(channels.IterContext(ctx, result)), nil } func (s *ShardHandle[Q, R, E]) getLastIndex() (uint64, error) { diff --git a/internal/raft/cluster_test.go b/internal/raft/cluster_test.go index a4d3114d84..c0d40e2883 100644 --- a/internal/raft/cluster_test.go +++ b/internal/raft/cluster_test.go @@ -142,7 +142,7 @@ func TestLeavingCluster(t *testing.T) { assertShardValue(ctx, t, 2, shards[1:]...) } -func TestChanges(t *testing.T) { +func TestStateIter(t *testing.T) { ctx := testContext(t) _, shards := startClusters(ctx, t, 2, func(b *raft.Builder) sm.Handle[int64, int64, IntEvent] { @@ -156,9 +156,11 @@ func TestChanges(t *testing.T) { assert.NoError(t, shards[1].Publish(ctx, IntEvent(1))) next, _ := iter.Pull(changes) - _, _ = next() v, _ := next() - assert.Equal(t, v, 2) + t.Logf("v1: %v", v) + + v, _ = next() + assert.Equal(t, 2, v) } func testBuilder(t *testing.T, addresses []*net.TCPAddr, id uint64, address string, controlBind *url.URL) *raft.Builder { diff --git a/internal/statemachine/handle.go b/internal/statemachine/handle.go index 2a3666cfe7..180b364219 100644 --- a/internal/statemachine/handle.go +++ b/internal/statemachine/handle.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "iter" - "reflect" "github.com/block/ftl/internal/channels" "github.com/block/ftl/internal/iterops" @@ -23,7 +22,12 @@ type Handle[Q any, R any, E any] interface { // Query retrieves the current state of the state machine. Query(ctx context.Context, query Q) (R, error) - // StateIter returns a stream of state based on a query. + // StateIter returns an iterator of state based on a query. + // + // The current state is returned as the first element of the iterator, + // followed by a stream of states for each change. + // + // The iterator is finished when the context is cancelled. StateIter(ctx context.Context, query Q) (iter.Seq[R], error) } @@ -69,18 +73,17 @@ func (l *localHandle[Q, R, E]) StateIter(ctx context.Context, query Q) (iter.Seq return nil, err } - return iterops.FlatMap(channels.IterContext(ctx, subs), func(struct{}) []R { - r, err := l.Query(ctx, query) - if err != nil { - logger.Warnf("query for changes failed: %s", err) - return nil - } - if reflect.DeepEqual(previous, r) { - return nil - } - previous = r - return []R{r} - }), nil + return iterops.Concat( + iterops.Const(previous), + iterops.FlatMap(channels.IterContext(ctx, subs), func(struct{}) iter.Seq[R] { + r, err := l.Query(ctx, query) + if err != nil { + logger.Warnf("query for changes failed: %s", err) + return iterops.Empty[R]() + } + return iterops.Const(r) + }), + ), nil } // SingleQueryHandle is a handle to a state machine that only supports a single query. @@ -117,6 +120,8 @@ func (h *SingleQueryHandle[Q, R, E]) Publish(ctx context.Context, msg E) error { } // StateIter returns a stream of state machine based on a query. +// The current state is returned as the first element of the iterator, +// followed by a stream of states for each change. // // The iterator is finished when the context is cancelled. func (h *SingleQueryHandle[Q, R, E]) StateIter(ctx context.Context) (iter.Seq[R], error) { diff --git a/internal/statemachine/handle_test.go b/internal/statemachine/handle_test.go index bd10ebb957..6a8b879172 100644 --- a/internal/statemachine/handle_test.go +++ b/internal/statemachine/handle_test.go @@ -2,6 +2,7 @@ package statemachine import ( "context" + "iter" "sync" "testing" "time" @@ -38,17 +39,19 @@ func (m *mockStateMachine) Subscribe(ctx context.Context) (<-chan struct{}, erro func (m *mockStateMachine) Publish(msg string) error { m.mu.Lock() + defer m.mu.Unlock() + m.value = msg m.updates = append(m.updates, msg) - m.mu.Unlock() - m.notifier.Notify(m.runningCtx) + return nil } func (m *mockStateMachine) Lookup(_ string) (string, error) { m.mu.Lock() defer m.mu.Unlock() + if m.queryErr != nil { return "", m.queryErr } @@ -74,7 +77,7 @@ func TestLocalHandle(t *testing.T) { assert.Equal(t, "new value", result) }) - t.Run("changes channel", func(t *testing.T) { + t.Run("StateIter", func(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) ctx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() @@ -90,12 +93,23 @@ func TestLocalHandle(t *testing.T) { assert.NoError(t, mock.Publish("updated value")) - for newValue := range changes1 { - assert.Equal(t, "updated value", newValue) - } + pull1, _ := iter.Pull[string](changes1) + pull2, _ := iter.Pull[string](changes2) + + v1, ok := pull1() + assert.True(t, ok) + assert.Equal(t, "initial", v1) + + v2, ok := pull2() + assert.True(t, ok) + assert.Equal(t, "initial", v2) + + v1, ok = pull1() + assert.True(t, ok) + assert.Equal(t, "updated value", v1) - for newValue := range changes2 { - assert.Equal(t, "updated value", newValue) - } + v2, ok = pull2() + assert.True(t, ok) + assert.Equal(t, "updated value", v2) }) }