From a725db901477329b4f3d5f872a65a632ab915b25 Mon Sep 17 00:00:00 2001 From: Koen Bollen Date: Thu, 7 Sep 2023 14:57:48 +0200 Subject: [PATCH] Timeout peers when the server is shutdown. (#55) When our backend needs to restart peers are disconnected can could get into the timeouts table, therefor they could reconnect. This PR will make sure all connections are closed when the server shuts down allowing peers to be timed out and reconnect later. --- cmd/signaling/main.go | 22 +++++++++++++++------- internal/signaling.go | 11 ++++++++--- internal/signaling/handler.go | 9 +++++++-- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/cmd/signaling/main.go b/cmd/signaling/main.go index 261a1a9..4afdad2 100644 --- a/cmd/signaling/main.go +++ b/cmd/signaling/main.go @@ -3,6 +3,7 @@ package main import ( "context" "math/rand" + "net" "net/http" "os" "os/signal" @@ -49,7 +50,7 @@ func main() { ) go credentialsClient.Run(ctx) - mux := internal.Signaling(ctx, store, credentialsClient) + mux, cleanup := internal.Signaling(ctx, store, credentialsClient) cors := cors.Default() handler := logging.Middleware(cors.Handler(mux), logger) @@ -64,6 +65,10 @@ func main() { Addr: addr, Handler: handler, + BaseContext: func(net.Listener) context.Context { + return ctx + }, + ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, IdleTimeout: 650 * time.Second, @@ -77,13 +82,16 @@ func main() { logger.Info("listening", zap.String("addr", addr)) <-ctx.Done() - if flushed != nil { - <-flushed - } + logger.Info("shutting down") - ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if err := server.Shutdown(ctx); err != nil { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() + if err := server.Shutdown(shutdownCtx); err != nil { logger.Fatal("failed to shutdown server", zap.Error(err)) } + + cleanup() + if flushed != nil { + <-flushed + } } diff --git a/internal/signaling.go b/internal/signaling.go index c8614be..dfa8f6a 100644 --- a/internal/signaling.go +++ b/internal/signaling.go @@ -11,10 +11,15 @@ import ( "github.com/poki/netlib/internal/util" ) -func Signaling(ctx context.Context, store stores.Store, credentialsClient *cloudflare.CredentialsClient) http.Handler { +func Signaling(ctx context.Context, store stores.Store, credentialsClient *cloudflare.CredentialsClient) (http.Handler, func()) { mux := http.NewServeMux() - mux.Handle("/v0/signaling", signaling.Handler(ctx, store, credentialsClient)) + openConnections, signaling := signaling.Handler(ctx, store, credentialsClient) + + cleanup := func() { + openConnections.Wait() + } + mux.Handle("/v0/signaling", signaling) hasCredentials := uint32(0) mux.HandleFunc("/ready", func(w http.ResponseWriter, r *http.Request) { @@ -43,5 +48,5 @@ func Signaling(ctx context.Context, store stores.Store, credentialsClient *cloud mux.HandleFunc("/health", healthCheck) mux.HandleFunc("/", healthCheck) - return mux + return mux, cleanup } diff --git a/internal/signaling/handler.go b/internal/signaling/handler.go index 3ac63aa..d3ca35f 100644 --- a/internal/signaling/handler.go +++ b/internal/signaling/handler.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "strings" + "sync" "time" "github.com/koenbollen/logging" @@ -18,13 +19,14 @@ import ( const MaxConnectionTime = 1 * time.Hour -func Handler(ctx context.Context, store stores.Store, cloudflare *cloudflare.CredentialsClient) http.HandlerFunc { +func Handler(ctx context.Context, store stores.Store, cloudflare *cloudflare.CredentialsClient) (*sync.WaitGroup, http.HandlerFunc) { manager := &TimeoutManager{ Store: store, } go manager.Run(ctx) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wg := &sync.WaitGroup{} + return wg, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() logger := logging.GetLogger(ctx) logger.Debug("upgrading connection") @@ -48,6 +50,9 @@ func Handler(ctx context.Context, store stores.Store, cloudflare *cloudflare.Cre util.ErrorAndAbort(w, r, http.StatusBadRequest, "", err) } + wg.Add(1) + defer wg.Done() + peer := &Peer{ store: store, conn: conn,