Skip to content

Commit

Permalink
Recreate listener if error is occured
Browse files Browse the repository at this point in the history
Signed-off-by: clyang82 <[email protected]>
  • Loading branch information
clyang82 committed Jan 3, 2025
1 parent d9c5257 commit d21b267
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 16 deletions.
36 changes: 23 additions & 13 deletions pkg/db/db_session/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (f *Default) DirectDB() *sql.DB {
return f.db
}

func waitForNotification(ctx context.Context, l *pq.Listener, callback func(id string)) {
func waitForNotification(ctx context.Context, l *pq.Listener, dbConfig *config.DatabaseConfig, channel string, callback func(id string)) {
logger := ocmlogger.NewOCMLogger(ctx)
for {
select {
Expand All @@ -149,21 +149,26 @@ func waitForNotification(ctx context.Context, l *pq.Listener, callback func(id s
if n != nil {
logger.V(4).Infof("Received event from channel [%s] : %s", n.Channel, n.Extra)
callback(n.Extra)
} else {
// nil notification means the connection was closed
logger.Infof("recreate the listener due to the connection loss")
l.Close()
// recreate the listener
l = newListener(ctx, dbConfig, channel)
}
case <-time.After(10 * time.Second):
logger.V(10).Infof("Received no events on channel during interval. Pinging source")
go func() {
// TODO: Need to handle the error, especially in cases of network failure.
err := l.Ping()
if err != nil {
logger.Error(err.Error())
}
}()
if err := l.Ping(); err != nil {
logger.Infof("recreate the listener due to %s", err.Error())
l.Close()
// recreate the listener
l = newListener(ctx, dbConfig, channel)
}
}
}
}

func newListener(ctx context.Context, dbConfig *config.DatabaseConfig, channel string, callback func(id string)) {
func newListener(ctx context.Context, dbConfig *config.DatabaseConfig, channel string) *pq.Listener {
logger := ocmlogger.NewOCMLogger(ctx)

plog := func(ev pq.ListenerEventType, err error) {
Expand All @@ -189,12 +194,17 @@ func newListener(ctx context.Context, dbConfig *config.DatabaseConfig, channel s
panic(err)
}

logger.Infof("Starting channeling monitor for %s", channel)
waitForNotification(ctx, listener, callback)
return listener
}

func (f *Default) NewListener(ctx context.Context, channel string, callback func(id string)) {
newListener(ctx, f.config, channel, callback)
func (f *Default) NewListener(ctx context.Context, channel string, callback func(id string)) *pq.Listener {
logger := ocmlogger.NewOCMLogger(ctx)

listener := newListener(ctx, f.config, channel)

logger.Infof("Starting channeling monitor for %s", channel)
go waitForNotification(ctx, listener, f.config, channel, callback)
return listener
}

func (f *Default) New(ctx context.Context) *gorm.DB {
Expand Down
7 changes: 5 additions & 2 deletions pkg/db/db_session/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"time"

"github.com/lib/pq"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
Expand Down Expand Up @@ -215,6 +216,8 @@ func (f *Test) ResetDB() {
f.wasDisconnected = true
}

func (f *Test) NewListener(ctx context.Context, channel string, callback func(id string)) {
newListener(ctx, f.config, channel, callback)
func (f *Test) NewListener(ctx context.Context, channel string, callback func(id string)) *pq.Listener {
listener := newListener(ctx, f.config, channel)
go waitForNotification(ctx, listener, f.config, channel, callback)
return listener
}
3 changes: 2 additions & 1 deletion pkg/db/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"

"github.com/lib/pq"
"gorm.io/gorm"

"github.com/openshift-online/maestro/pkg/config"
Expand All @@ -16,5 +17,5 @@ type SessionFactory interface {
CheckConnection() error
Close() error
ResetDB()
NewListener(ctx context.Context, channel string, callback func(id string))
NewListener(ctx context.Context, channel string, callback func(id string)) *pq.Listener
}
58 changes: 58 additions & 0 deletions test/integration/db_listener_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package integration

import (
"context"
"testing"
"time"

"github.com/openshift-online/maestro/test"
)

func TestWaitForNotification(t *testing.T) {
// it is used to check the result of the notification
result := make(chan string)

h, _ := test.RegisterIntegration(t)

account := h.NewRandAccount()
ctx, cancel := context.WithCancel(h.NewAuthenticatedContext(account))
defer func() {
cancel()
}()

g2 := h.Env().Database.SessionFactory.New(ctx)

listener := h.Env().Database.SessionFactory.NewListener(ctx, "events", func(id string) {
result <- id
})
var originalListenerId string
// find the original listener id in the pg_stat_activity table
g2.Raw("SELECT pid FROM pg_stat_activity WHERE query LIKE 'LISTEN%'").Scan(&originalListenerId)
if originalListenerId == "" {
t.Errorf("the original Listener was not recreated")
}

// Simulate an errListenerClosed and wait for the listener to be recreated
listener.Close()
time.Sleep(2 * time.Second)

var newListenerId string
g2.Raw("SELECT pid FROM pg_stat_activity WHERE query LIKE 'LISTEN%'").Scan(&newListenerId)
if newListenerId == "" {
t.Errorf("the new Listener was not created")
}

// Validate the listener was recreated
if originalListenerId == newListenerId {
t.Errorf("Listener was not recreated")
}
// send a notification to the new listener
g2.Exec("NOTIFY events, 'test'")

// wait for the notification to be received
time.Sleep(1 * time.Second)

if <-result != "test" {
t.Errorf("the notification was not received")
}
}

0 comments on commit d21b267

Please sign in to comment.