diff --git a/internal/topo/node/source_pool.go b/internal/topo/node/source_pool.go index 8235e12f2b..7c16bffd4f 100644 --- a/internal/topo/node/source_pool.go +++ b/internal/topo/node/source_pool.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "sync" + "sync/atomic" "time" "github.com/lf-edge/ekuiper/internal/binder/io" @@ -58,6 +59,7 @@ func getSourceInstance(node *SourceNode, index int) (*sourceInstance, error) { } } // attach + s.count.Add(1) instanceKey := fmt.Sprintf("%s.%s.%d", rkey, node.ctx.GetRuleId(), index) err := s.attach(instanceKey, node.bufferLength) if err != nil { @@ -184,13 +186,13 @@ func (p *sourcePool) deleteInstance(k string, node *SourceNode, index int) { defer p.Unlock() s, ok := p.registry[k] if ok { - + s.count.Add(-1) instanceKey := fmt.Sprintf("%s.%s.%d", k, node.ctx.GetRuleId(), index) - end := s.detach(instanceKey) + s.detach(instanceKey) s.detachSchema(node.ctx.GetRuleId()) - if end { + if s.count.Load() == 0 { + s.ctx.GetLogger().Info("cancel source") s.cancel() - s.dataCh.Close() delete(p.registry, k) } @@ -208,8 +210,8 @@ type sourceInstance struct { type sourceSingleton struct { *sourceInstance // immutable cancel context.CancelFunc // immutable - - outputs map[string]*sourceInstanceChannels // read-write lock + count atomic.Int32 + outputs map[string]*sourceInstanceChannels // read-write lock sync.RWMutex } @@ -297,6 +299,7 @@ func (ss *sourceSingleton) attach(instanceKey string, bl int) error { err = func() error { ss.Lock() defer ss.Unlock() + ss.ctx.GetLogger().Infof("attaching instance") if _, ok := ss.outputs[instanceKey]; !ok { ss.outputs[instanceKey] = newSourceInstanceChannels(bl) } else { @@ -325,10 +328,11 @@ func (ss *sourceSingleton) detachSchema(ruleID string) { } } -// detach Detach an instance and return if the singleton is ended +// detach an instance and return if the singleton is ended func (ss *sourceSingleton) detach(instanceKey string) bool { ss.Lock() defer ss.Unlock() + ss.ctx.GetLogger().Infof("detach source instance %s", instanceKey) if chs, ok := ss.outputs[instanceKey]; ok { chs.dataCh.Close() } else { @@ -337,10 +341,6 @@ func (ss *sourceSingleton) detach(instanceKey string) bool { return false } delete(ss.outputs, instanceKey) - if len(ss.outputs) == 0 { - ss.cancel() - return true - } return false } diff --git a/internal/topo/node/source_pool_test.go b/internal/topo/node/source_pool_test.go index 2f24d34fdb..4ae9a0189e 100644 --- a/internal/topo/node/source_pool_test.go +++ b/internal/topo/node/source_pool_test.go @@ -16,6 +16,9 @@ package node import ( "testing" + "time" + + "github.com/stretchr/testify/assert" "github.com/lf-edge/ekuiper/internal/conf" "github.com/lf-edge/ekuiper/internal/topo/context" @@ -102,3 +105,63 @@ func TestSourcePool(t *testing.T) { removeSourceInstance(n2) } + +func TestSourcePoolRecreate(t *testing.T) { + n := NewSourceNode("test", ast.TypeStream, nil, &ast.Options{ + DATASOURCE: "demo1", + TYPE: "mock", + SHARED: true, + }, &api.RuleOption{SendError: false}, false, false, nil) + contextLogger := conf.Log.WithField("rule", "mockRule0") + ctx := context.WithValue(context.Background(), context.LoggerKey, contextLogger) + tempStore, _ := state.CreateStore("mockRule0", api.AtMostOnce) + n.ctx = ctx.WithMeta("mockRule0", "test", tempStore) + + // Test add source instance + _, err := getSourceInstance(n, 0) + assert.NoError(t, err) + time.Sleep(10 * time.Millisecond) + go func() { + removeSourceInstance(n) + }() + _, err = getSourceInstance(n, 0) + assert.NoError(t, err) + + poolLen := len(pool.registry) + if poolLen != 1 { + t.Errorf("source instances length unmatch: expect %d but got %d", 1, poolLen) + return + } + si, ok := pool.registry["mock.test"] + if !ok { + t.Errorf("source instances pool unmatch: can't find key %s", "mock.test") + return + } + outputLen := len(si.outputs) + if outputLen != 1 { + t.Errorf("source instances length unmatch: expect %d but got %d", 3, outputLen) + return + } + time.Sleep(1 * time.Second) + // Test add source instance + removeSourceInstance(n) + _, err = getSourceInstance(n, 0) + assert.NoError(t, err) + + poolLen = len(pool.registry) + if poolLen != 1 { + t.Errorf("source instances length unmatch: expect %d but got %d", 1, poolLen) + return + } + si, ok = pool.registry["mock.test"] + if !ok { + t.Errorf("source instances pool unmatch: can't find key %s", "mock.test") + return + } + outputLen = len(si.outputs) + if outputLen != 1 { + t.Errorf("source instances length unmatch: expect %d but got %d", 3, outputLen) + return + } + time.Sleep(1 * time.Second) +}