Skip to content

Commit

Permalink
Refactor source consumption (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
nrwiersma authored Jun 8, 2018
1 parent 45049dd commit dd297e3
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 33 deletions.
4 changes: 3 additions & 1 deletion kafka/source.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package kafka

import (
"time"

"github.com/Shopify/sarama"
"github.com/bsm/sarama-cluster"
"github.com/pkg/errors"
Expand Down Expand Up @@ -112,7 +114,7 @@ func (s *Source) Consume() (key, value interface{}, err error) {

return k, v, nil

default:
case <-time.After(100 * time.Millisecond):
return nil, nil, nil
}
}
Expand Down
75 changes: 48 additions & 27 deletions task.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ package streams

import (
"sync"
"time"

"github.com/msales/pkg/log"
"github.com/msales/pkg/stats"
)

// record pairs a consumed key/value message with the Node that should
// process it. Values of this type are sent over streamTask.records.
type record struct {
	Node       Node
	Key, Value interface{}
}

// ErrorFunc is a callback that receives an error. A streamTask stores one
// in its errorFn field and invokes it from handleError.
type ErrorFunc func(error)

// TaskFunc is a function that receives a *streamTask. NewTask accepts a
// variadic list of them as opts (presumably to configure the task — the
// NewTask body is not visible here, confirm).
type TaskFunc func(*streamTask)
Expand Down Expand Up @@ -37,9 +42,11 @@ type streamTask struct {
logger log.Logger
stats stats.Stats

running bool
errorFn ErrorFunc
wg sync.WaitGroup
running bool
errorFn ErrorFunc
records chan record
runWg sync.WaitGroup
sourceWg sync.WaitGroup
}

func NewTask(topology *Topology, opts ...TaskFunc) Task {
Expand All @@ -63,35 +70,21 @@ func (t *streamTask) run() {
return
}

t.records = make(chan record, 1000)
t.running = true
t.wg.Add(1)
defer t.wg.Done()

ctx := NewProcessorContext(t, t.logger, t.stats)
t.setupTopology(ctx)

for t.running == true {
var empty = 0;
for source, node := range t.topology.Sources() {
k, v, err := source.Consume()
if err != nil {
t.handleError(err)
}
t.consumeSources()

if k == nil && v == nil {
empty++
continue
}
t.runWg.Add(1)
defer t.runWg.Done()

ctx.currentNode = node
if err := node.Process(k, v); err != nil {
t.handleError(err)
}
}

// All the sources were empty, wait a short while
if empty == len(t.topology.Sources()) {
time.Sleep(10 * time.Millisecond)
for r := range t.records {
ctx.currentNode = r.Node
if err := r.Node.Process(r.Key, r.Value); err != nil {
t.handleError(err)
}
}
}
Expand Down Expand Up @@ -124,6 +117,32 @@ func (t *streamTask) handleError(err error) {
t.errorFn(err)
}

// consumeSources starts one goroutine per topology source. Each goroutine
// polls its source while t.running is true and pushes every non-empty
// message onto t.records, tagged with the node that should process it.
// Errors returned by Consume are reported through t.handleError.
func (t *streamTask) consumeSources() {
	for source, node := range t.topology.Sources() {
		// Register with the WaitGroup before the goroutine is started.
		// Calling Add inside the goroutine races with sourceWg.Wait():
		// Wait could observe a zero counter and return before the
		// goroutine has been scheduled.
		t.sourceWg.Add(1)

		go func(source Source, node Node) {
			defer t.sourceWg.Done()

			// NOTE(review): t.running is read here without synchronization
			// while another goroutine may write it — confirm whether an
			// atomic flag or a done channel is wanted.
			for t.running {
				k, v, err := source.Consume()
				if err != nil {
					t.handleError(err)
				}

				// Nothing was consumed on this poll; try again.
				if k == nil && v == nil {
					continue
				}

				t.records <- record{
					Node:  node,
					Key:   k,
					Value: v,
				}
			}
		}(source, node)
	}
}

// Start launches t.run in a new goroutine and returns immediately.
func (t *streamTask) Start() {
go t.run()
}
Expand All @@ -144,8 +163,10 @@ func (t *streamTask) OnError(fn ErrorFunc) {

// Close stops the task: it clears the running flag, waits on the source
// WaitGroup, closes the records channel, waits on the run WaitGroup, and
// finally returns the result of closeTopology.
func (t *streamTask) Close() error {
t.running = false
// Wait on sourceWg before closing t.records, so the consumer goroutines
// that send on that channel are waited for first.
t.sourceWg.Wait()

// NOTE(review): t.wg looks like the pre-refactor WaitGroup that
// sourceWg/runWg replace in this diff — confirm whether this line is a
// removed diff line that should not be present.
t.wg.Wait()
// Closing the channel ends the `for r := range t.records` loop in run.
close(t.records)
t.runWg.Wait()

return t.closeTopology()
}
4 changes: 0 additions & 4 deletions topology.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ func (n *SourceNode) Children() []Node {
}

func (n *SourceNode) Process(key, value interface{}) error {
if key == nil && value == nil {
return nil
}

n.ctx.Stats().Inc("node.throughput", 1, 1.0, map[string]string{"name": n.name})

return n.ctx.Forward(key, value)
Expand Down
1 change: 0 additions & 1 deletion topology_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ func TestSourceNode_Process(t *testing.T) {
n := SourceNode{}
n.WithContext(ctx)

n.Process(nil, nil)
n.Process(key, value)

ctx.AssertExpectations()
Expand Down

0 comments on commit dd297e3

Please sign in to comment.