Skip to content

Commit

Permalink
Merge pull request #72 from blendle/interrupt
Browse files Browse the repository at this point in the history
Add `streamutil.Interrupt()` function
  • Loading branch information
JeanMertz authored May 1, 2018
2 parents 24dbbbd + 0fde718 commit a5c6e94
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 17 deletions.
2 changes: 1 addition & 1 deletion streamclient/inmemclient/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func NewConsumer(options ...streamconfig.Option) (stream.Consumer, error) {
// This functionality is enabled by default, but can be disabled through a
// configuration flag.
if c.c.HandleInterrupt {
c.signals = make(chan os.Signal, 1)
c.signals = make(chan os.Signal, 3)
go streamutil.HandleInterrupts(c.signals, c.Close, c.logger)
}

Expand Down
2 changes: 1 addition & 1 deletion streamclient/inmemclient/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func NewProducer(options ...streamconfig.Option) (stream.Producer, error) {
// This functionality is enabled by default, but can be disabled through a
// configuration flag.
if p.c.HandleInterrupt {
p.signals = make(chan os.Signal, 1)
p.signals = make(chan os.Signal, 3)
go streamutil.HandleInterrupts(p.signals, p.Close, p.logger)
}

Expand Down
2 changes: 1 addition & 1 deletion streamclient/kafkaclient/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func NewConsumer(options ...streamconfig.Option) (stream.Consumer, error) {
// This functionality is enabled by default, but can be disabled through a
// configuration flag.
if c.c.HandleInterrupt {
c.signals = make(chan os.Signal, 1)
c.signals = make(chan os.Signal, 3)
go streamutil.HandleInterrupts(c.signals, c.Close, c.logger)
}

Expand Down
2 changes: 1 addition & 1 deletion streamclient/kafkaclient/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func NewProducer(options ...streamconfig.Option) (stream.Producer, error) {
// This functionality is enabled by default, but can be disabled through a
// configuration flag.
if p.c.HandleInterrupt {
p.signals = make(chan os.Signal, 1)
p.signals = make(chan os.Signal, 3)
go streamutil.HandleInterrupts(p.signals, p.Close, p.logger)
}

Expand Down
2 changes: 1 addition & 1 deletion streamclient/standardstreamclient/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func NewConsumer(options ...streamconfig.Option) (stream.Consumer, error) {
// This functionality is enabled by default, but can be disabled through a
// configuration flag.
if c.c.HandleInterrupt {
c.signals = make(chan os.Signal, 1)
c.signals = make(chan os.Signal, 3)
go streamutil.HandleInterrupts(c.signals, c.Close, c.logger)
}

Expand Down
2 changes: 1 addition & 1 deletion streamclient/standardstreamclient/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func NewProducer(options ...streamconfig.Option) (stream.Producer, error) {
// This functionality is enabled by default, but can be disabled through a
// configuration flag.
if p.c.HandleInterrupt {
p.signals = make(chan os.Signal, 1)
p.signals = make(chan os.Signal, 3)
go streamutil.HandleInterrupts(p.signals, p.Close, p.logger)
}

Expand Down
17 changes: 15 additions & 2 deletions streamutil/interrupt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,38 @@ package streamutil
import (
"os"
"os/signal"
"syscall"
"time"

"go.uber.org/zap"
)

// Interrupt returns a channel that receives a signal when the application
// receives either an SIGINT or SIGTERM signal. This is provided for convenience
// when dealing with a select statement and receiving stream messages, making it
// easy to cleanly exit after fully handling one message, but before handling
// the next message.
func Interrupt() <-chan os.Signal {
ch := make(chan os.Signal, 3)
signal.Notify(ch, os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT)

return ch
}

// HandleInterrupts monitors for an interrupt signal, and calls the provided
// closer function once received. It has a built-in timeout capability to force
// terminate the application when the closer takes too long to close, or returns
// an error during closing.
func HandleInterrupts(signals chan os.Signal, closer func() error, logger *zap.Logger) {
signal.Notify(signals, os.Interrupt)
signal.Notify(signals, os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT)

s, ok := <-signals
if !ok {
return
}

logger.Info(
"Got interrupt signal, cleaning up. Use ^C again to exit immediately.",
"Got interrupt signal, cleaning up. Use ^C to exit immediately.",
zap.String("signal", s.String()),
)

Expand Down
73 changes: 64 additions & 9 deletions streamutil/interrupt_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package streamutil_test

import (
"bufio"
"bytes"
"fmt"
"os"
"os/exec"
"syscall"
"testing"
"time"

Expand All @@ -12,6 +16,45 @@ import (
"go.uber.org/zap"
)

func TestInterrupt(t *testing.T) {
t.Parallel()

if os.Getenv("BE_TESTING_FATAL") == "1" {
select {
case s := <-streamutil.Interrupt():
println("interrupt received:", s.String())
return
case <-time.After(5 * time.Second):
os.Exit(1)
}

return
}

var tests = []os.Signal{
os.Interrupt,
syscall.SIGTERM,
syscall.SIGQUIT,
}

for _, tt := range tests {
t.Run(tt.String(), func(t *testing.T) {
cmd := exec.Command(os.Args[0], "-test.run="+t.Name())
cmd.Env = append(os.Environ(), "BE_TESTING_FATAL=1")

var b bytes.Buffer
cmd.Stderr = bufio.NewWriter(&b)
require.NoError(t, cmd.Start())
time.Sleep(150 * time.Millisecond)

require.NoError(t, cmd.Process.Signal(tt))
require.NoError(t, cmd.Wait())

assert.Contains(t, b.String(), "interrupt received: "+tt.String())
})
}
}

func TestHandleInterrupts(t *testing.T) {
t.Parallel()

Expand All @@ -26,19 +69,31 @@ func TestHandleInterrupts(t *testing.T) {
}

go streamutil.HandleInterrupts(ch, fn, logger)
ch <- os.Interrupt

time.Sleep(10 * time.Millisecond)
time.Sleep(5 * time.Second)
os.Exit(1)
}

return
var tests = []os.Signal{
os.Interrupt,
syscall.SIGTERM,
syscall.SIGQUIT,
}

cmd := exec.Command(os.Args[0], "-test.run="+t.Name())
cmd.Env = append(os.Environ(), "BE_TESTING_FATAL=1")
for _, tt := range tests {
t.Run(tt.String(), func(t *testing.T) {
cmd := exec.Command(os.Args[0], "-test.run="+t.Name())
cmd.Env = append(os.Environ(), "BE_TESTING_FATAL=1")

out, err := cmd.CombinedOutput()
require.Nil(t, err, "output received: %s", string(out))
var b bytes.Buffer
cmd.Stderr = bufio.NewWriter(&b)
require.NoError(t, cmd.Start())
time.Sleep(150 * time.Millisecond)

assert.Contains(t, string(out), "Got interrupt signal")
assert.Contains(t, string(out), "closed!")
require.NoError(t, cmd.Process.Signal(tt))
require.NoError(t, cmd.Wait())

assert.Contains(t, b.String(), fmt.Sprintf(`{"signal": "%s"}`, tt.String()))
})
}
}

0 comments on commit a5c6e94

Please sign in to comment.