diff --git a/kafka/source.go b/kafka/source.go index d7d9924..ef480e6 100644 --- a/kafka/source.go +++ b/kafka/source.go @@ -1,6 +1,8 @@ package kafka import ( + "time" + "github.com/Shopify/sarama" "github.com/bsm/sarama-cluster" "github.com/pkg/errors" @@ -112,7 +114,7 @@ func (s *Source) Consume() (key, value interface{}, err error) { return k, v, nil - default: + case <-time.After(100 * time.Millisecond): return nil, nil, nil } } diff --git a/task.go b/task.go index 64f58a0..f498ca5 100644 --- a/task.go +++ b/task.go @@ -2,12 +2,17 @@ package streams import ( "sync" - "time" "github.com/msales/pkg/log" "github.com/msales/pkg/stats" ) +type record struct { + Node Node + Key interface{} + Value interface{} +} + type ErrorFunc func(error) type TaskFunc func(*streamTask) @@ -37,9 +42,11 @@ type streamTask struct { logger log.Logger stats stats.Stats - running bool - errorFn ErrorFunc - wg sync.WaitGroup + running bool + errorFn ErrorFunc + records chan record + runWg sync.WaitGroup + sourceWg sync.WaitGroup } func NewTask(topology *Topology, opts ...TaskFunc) Task { @@ -63,35 +70,21 @@ func (t *streamTask) run() { return } + t.records = make(chan record, 1000) t.running = true - t.wg.Add(1) - defer t.wg.Done() ctx := NewProcessorContext(t, t.logger, t.stats) t.setupTopology(ctx) - for t.running == true { - var empty = 0; - for source, node := range t.topology.Sources() { - k, v, err := source.Consume() - if err != nil { - t.handleError(err) - } + t.consumeSources() - if k == nil && v == nil { - empty++ - continue - } + t.runWg.Add(1) + defer t.runWg.Done() - ctx.currentNode = node - if err := node.Process(k, v); err != nil { - t.handleError(err) - } - } - - // All the sources where empty, wait a short while - if empty == len(t.topology.Sources()) { - time.Sleep(10 * time.Millisecond) + for r := range t.records { + ctx.currentNode = r.Node + if err := r.Node.Process(r.Key, r.Value); err != nil { + t.handleError(err) } } } @@ -124,6 +117,32 @@ func (t *streamTask) handleError(err error) { t.errorFn(err) } +func (t *streamTask) consumeSources() { + for source, node := range t.topology.Sources() { + go func(source Source, node Node) { + t.sourceWg.Add(1) + defer t.sourceWg.Done() + + for t.running { + k, v, err := source.Consume() + if err != nil { + t.handleError(err) + } + + if k == nil && v == nil { + continue + } + + t.records <- record{ + Node: node, + Key: k, + Value: v, + } + } + }(source, node) + } +} + func (t *streamTask) Start() { go t.run() } @@ -144,8 +163,10 @@ func (t *streamTask) OnError(fn ErrorFunc) { func (t *streamTask) Close() error { t.running = false + t.sourceWg.Wait() - t.wg.Wait() + close(t.records) + t.runWg.Wait() return t.closeTopology() } diff --git a/topology.go b/topology.go index 7959e0d..522915f 100644 --- a/topology.go +++ b/topology.go @@ -30,10 +30,6 @@ func (n *SourceNode) Children() []Node { } func (n *SourceNode) Process(key, value interface{}) error { - if key == nil && value == nil { - return nil - } - n.ctx.Stats().Inc("node.throughput", 1, 1.0, map[string]string{"name": n.name}) return n.ctx.Forward(key, value) diff --git a/topology_test.go b/topology_test.go index 864c703..3f33d43 100644 --- a/topology_test.go +++ b/topology_test.go @@ -48,7 +48,6 @@ func TestSourceNode_Process(t *testing.T) { n := SourceNode{} n.WithContext(ctx) - n.Process(nil, nil) n.Process(key, value) ctx.AssertExpectations()