diff --git a/cmd/signaling/main.go b/cmd/signaling/main.go index 261a1a9..4afdad2 100644 --- a/cmd/signaling/main.go +++ b/cmd/signaling/main.go @@ -3,6 +3,7 @@ package main import ( "context" "math/rand" + "net" "net/http" "os" "os/signal" @@ -49,7 +50,7 @@ func main() { ) go credentialsClient.Run(ctx) - mux := internal.Signaling(ctx, store, credentialsClient) + mux, cleanup := internal.Signaling(ctx, store, credentialsClient) cors := cors.Default() handler := logging.Middleware(cors.Handler(mux), logger) @@ -64,6 +65,10 @@ func main() { Addr: addr, Handler: handler, + BaseContext: func(net.Listener) context.Context { + return ctx + }, + ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, IdleTimeout: 650 * time.Second, @@ -77,13 +82,16 @@ func main() { logger.Info("listening", zap.String("addr", addr)) <-ctx.Done() - if flushed != nil { - <-flushed - } + logger.Info("shutting down") - ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if err := server.Shutdown(ctx); err != nil { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() + if err := server.Shutdown(shutdownCtx); err != nil { logger.Fatal("failed to shutdown server", zap.Error(err)) } + + cleanup() + if flushed != nil { + <-flushed + } } diff --git a/internal/signaling.go b/internal/signaling.go index c8614be..dfa8f6a 100644 --- a/internal/signaling.go +++ b/internal/signaling.go @@ -11,10 +11,15 @@ import ( "github.com/poki/netlib/internal/util" ) -func Signaling(ctx context.Context, store stores.Store, credentialsClient *cloudflare.CredentialsClient) http.Handler { +func Signaling(ctx context.Context, store stores.Store, credentialsClient *cloudflare.CredentialsClient) (http.Handler, func()) { mux := http.NewServeMux() - mux.Handle("/v0/signaling", signaling.Handler(ctx, store, credentialsClient)) + openConnections, signaling := signaling.Handler(ctx, store, credentialsClient) + + cleanup := func() { + openConnections.Wait() + } + mux.Handle("/v0/signaling", signaling) hasCredentials := uint32(0) mux.HandleFunc("/ready", func(w http.ResponseWriter, r *http.Request) { @@ -43,5 +48,5 @@ func Signaling(ctx context.Context, store stores.Store, credentialsClient *cloud mux.HandleFunc("/health", healthCheck) mux.HandleFunc("/", healthCheck) - return mux + return mux, cleanup } diff --git a/internal/signaling/handler.go b/internal/signaling/handler.go index 3ac63aa..d3ca35f 100644 --- a/internal/signaling/handler.go +++ b/internal/signaling/handler.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "strings" + "sync" "time" "github.com/koenbollen/logging" @@ -18,13 +19,14 @@ import ( const MaxConnectionTime = 1 * time.Hour -func Handler(ctx context.Context, store stores.Store, cloudflare *cloudflare.CredentialsClient) http.HandlerFunc { +func Handler(ctx context.Context, store stores.Store, cloudflare *cloudflare.CredentialsClient) (*sync.WaitGroup, http.HandlerFunc) { manager := &TimeoutManager{ Store: store, } go manager.Run(ctx) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wg := &sync.WaitGroup{} + return wg, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() logger := logging.GetLogger(ctx) logger.Debug("upgrading connection") @@ -48,6 +50,9 @@ func Handler(ctx context.Context, store stores.Store, cloudflare *cloudflare.Cre util.ErrorAndAbort(w, r, http.StatusBadRequest, "", err) } + wg.Add(1) + defer wg.Done() + peer := &Peer{ store: store, conn: conn,