From d21b267edcabbc1747b1552faab818faad73a711 Mon Sep 17 00:00:00 2001 From: clyang82 Date: Wed, 18 Dec 2024 16:32:53 +0800 Subject: [PATCH] Recreate listener if error is occured Signed-off-by: clyang82 --- pkg/db/db_session/default.go | 36 ++++++++++------- pkg/db/db_session/test.go | 7 +++- pkg/db/session.go | 3 +- test/integration/db_listener_test.go | 58 ++++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 16 deletions(-) create mode 100644 test/integration/db_listener_test.go diff --git a/pkg/db/db_session/default.go b/pkg/db/db_session/default.go index 9c557a40..89aa3d2a 100755 --- a/pkg/db/db_session/default.go +++ b/pkg/db/db_session/default.go @@ -138,7 +138,7 @@ func (f *Default) DirectDB() *sql.DB { return f.db } -func waitForNotification(ctx context.Context, l *pq.Listener, callback func(id string)) { +func waitForNotification(ctx context.Context, l *pq.Listener, dbConfig *config.DatabaseConfig, channel string, callback func(id string)) { logger := ocmlogger.NewOCMLogger(ctx) for { select { @@ -149,21 +149,26 @@ func waitForNotification(ctx context.Context, l *pq.Listener, callback func(id s if n != nil { logger.V(4).Infof("Received event from channel [%s] : %s", n.Channel, n.Extra) callback(n.Extra) + } else { + // nil notification means the connection was closed + logger.Infof("recreate the listener due to the connection loss") + l.Close() + // recreate the listener + l = newListener(ctx, dbConfig, channel) } case <-time.After(10 * time.Second): logger.V(10).Infof("Received no events on channel during interval. Pinging source") - go func() { - // TODO: Need to handle the error, especially in cases of network failure. - err := l.Ping() - if err != nil { - logger.Error(err.Error()) - } - }() + if err := l.Ping(); err != nil { + logger.Infof("recreate the listener due to %s", err.Error()) + l.Close() + // recreate the listener + l = newListener(ctx, dbConfig, channel) + } } } } -func newListener(ctx context.Context, dbConfig *config.DatabaseConfig, channel string, callback func(id string)) { +func newListener(ctx context.Context, dbConfig *config.DatabaseConfig, channel string) *pq.Listener { logger := ocmlogger.NewOCMLogger(ctx) plog := func(ev pq.ListenerEventType, err error) { @@ -189,12 +194,17 @@ func newListener(ctx context.Context, dbConfig *config.DatabaseConfig, channel s panic(err) } - logger.Infof("Starting channeling monitor for %s", channel) - waitForNotification(ctx, listener, callback) + return listener } -func (f *Default) NewListener(ctx context.Context, channel string, callback func(id string)) { - newListener(ctx, f.config, channel, callback) +func (f *Default) NewListener(ctx context.Context, channel string, callback func(id string)) *pq.Listener { + logger := ocmlogger.NewOCMLogger(ctx) + + listener := newListener(ctx, f.config, channel) + + logger.Infof("Starting channeling monitor for %s", channel) + go waitForNotification(ctx, listener, f.config, channel, callback) + return listener } func (f *Default) New(ctx context.Context) *gorm.DB { diff --git a/pkg/db/db_session/test.go b/pkg/db/db_session/test.go index f21c3358..cfc76c11 100755 --- a/pkg/db/db_session/test.go +++ b/pkg/db/db_session/test.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + "github.com/lib/pq" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -215,6 +216,8 @@ func (f *Test) ResetDB() { f.wasDisconnected = true } -func (f *Test) NewListener(ctx context.Context, channel string, callback func(id string)) { - newListener(ctx, f.config, channel, callback) +func (f *Test) NewListener(ctx context.Context, channel string, callback func(id string)) *pq.Listener { + listener := newListener(ctx, f.config, channel) + go waitForNotification(ctx, listener, f.config, channel, callback) + return listener } diff --git a/pkg/db/session.go b/pkg/db/session.go index 10e4a5ad..2e818297 100755 --- a/pkg/db/session.go +++ b/pkg/db/session.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" + "github.com/lib/pq" "gorm.io/gorm" "github.com/openshift-online/maestro/pkg/config" @@ -16,5 +17,5 @@ type SessionFactory interface { CheckConnection() error Close() error ResetDB() - NewListener(ctx context.Context, channel string, callback func(id string)) + NewListener(ctx context.Context, channel string, callback func(id string)) *pq.Listener } diff --git a/test/integration/db_listener_test.go b/test/integration/db_listener_test.go new file mode 100644 index 00000000..e168c83d --- /dev/null +++ b/test/integration/db_listener_test.go @@ -0,0 +1,58 @@ +package integration + +import ( + "context" + "testing" + "time" + + "github.com/openshift-online/maestro/test" +) + +func TestWaitForNotification(t *testing.T) { + // it is used to check the result of the notification + result := make(chan string) + + h, _ := test.RegisterIntegration(t) + + account := h.NewRandAccount() + ctx, cancel := context.WithCancel(h.NewAuthenticatedContext(account)) + defer func() { + cancel() + }() + + g2 := h.Env().Database.SessionFactory.New(ctx) + + listener := h.Env().Database.SessionFactory.NewListener(ctx, "events", func(id string) { + result <- id + }) + var originalListenerId string + // find the original listener id in the pg_stat_activity table + g2.Raw("SELECT pid FROM pg_stat_activity WHERE query LIKE 'LISTEN%'").Scan(&originalListenerId) + if originalListenerId == "" { + t.Errorf("the original Listener was not recreated") + } + + // Simulate an errListenerClosed and wait for the listener to be recreated + listener.Close() + time.Sleep(2 * time.Second) + + var newListenerId string + g2.Raw("SELECT pid FROM pg_stat_activity WHERE query LIKE 'LISTEN%'").Scan(&newListenerId) + if newListenerId == "" { + t.Errorf("the new Listener was not created") + } + + // Validate the listener was recreated + if originalListenerId == newListenerId { + t.Errorf("Listener was not recreated") + } + // send a notification to the new listener + g2.Exec("NOTIFY events, 'test'") + + // wait for the notification to be received + time.Sleep(1 * time.Second) + + if <-result != "test" { + t.Errorf("the notification was not received") + } +}