diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index a8f238ae0a3..bfffe887b3f 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -50,6 +50,7 @@ import ( "github.com/onflow/flow-go/engine/access/rest" commonrest "github.com/onflow/flow-go/engine/access/rest/common" "github.com/onflow/flow-go/engine/access/rest/router" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" rpcConnection "github.com/onflow/flow-go/engine/access/rpc/connection" @@ -227,8 +228,9 @@ func DefaultAccessNodeConfig() *AccessNodeConfig { IdleTimeout: rest.DefaultIdleTimeout, MaxRequestSize: commonrest.DefaultMaxRequestSize, }, - MaxMsgSize: grpcutils.DefaultMaxMsgSize, - CompressorName: grpcutils.NoCompressor, + MaxMsgSize: grpcutils.DefaultMaxMsgSize, + CompressorName: grpcutils.NoCompressor, + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), }, stateStreamConf: statestreambackend.Config{ MaxExecutionDataMsgSize: grpcutils.DefaultMaxMsgSize, @@ -1450,6 +1452,11 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { "registerdb-pruning-threshold", defaultConfig.registerDBPruneThreshold, fmt.Sprintf("specifies the number of blocks below the latest stored block height to keep in register db. default: %d", defaultConfig.registerDBPruneThreshold)) + + flags.DurationVar(&builder.rpcConf.WebSocketConfig.InactivityTimeout, + "websocket-inactivity-timeout", + defaultConfig.rpcConf.WebSocketConfig.InactivityTimeout, + "specifies the duration a WebSocket connection can remain open without any active subscriptions before being automatically closed") }).ValidateFlags(func() error { if builder.supportsObserver && (builder.PublicNetworkConfig.BindAddress == cmd.NotSet || builder.PublicNetworkConfig.BindAddress == "") { return errors.New("public-network-address must be set if supports-observer is true") diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index 1bb6a8c04bb..ea24d064262 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -169,7 +169,6 @@ type ObserverServiceConfig struct { registerCacheSize uint programCacheSize uint registerDBPruneThreshold uint64 - websocketConfig websockets.Config } // DefaultObserverServiceConfig defines all the default values for the ObserverServiceConfig @@ -200,8 +199,9 @@ func DefaultObserverServiceConfig() *ObserverServiceConfig { IdleTimeout: rest.DefaultIdleTimeout, MaxRequestSize: commonrest.DefaultMaxRequestSize, }, - MaxMsgSize: grpcutils.DefaultMaxMsgSize, - CompressorName: grpcutils.NoCompressor, + MaxMsgSize: grpcutils.DefaultMaxMsgSize, + CompressorName: grpcutils.NoCompressor, + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), }, stateStreamConf: statestreambackend.Config{ MaxExecutionDataMsgSize: grpcutils.DefaultMaxMsgSize, @@ -254,7 +254,6 @@ func DefaultObserverServiceConfig() *ObserverServiceConfig { registerCacheSize: 0, programCacheSize: 0, registerDBPruneThreshold: pruner.DefaultThreshold, - websocketConfig: websockets.NewDefaultWebsocketConfig(), } } @@ -814,6 +813,11 @@ func (builder *ObserverServiceBuilder) extraFlags() { "registerdb-pruning-threshold", defaultConfig.registerDBPruneThreshold, fmt.Sprintf("specifies the number of blocks below the latest stored block height to keep in register db. default: %d", defaultConfig.registerDBPruneThreshold)) + + flags.DurationVar(&builder.rpcConf.WebSocketConfig.InactivityTimeout, + "websocket-inactivity-timeout", + defaultConfig.rpcConf.WebSocketConfig.InactivityTimeout, + "specifies the duration a WebSocket connection can remain open without any active subscriptions before being automatically closed") }).ValidateFlags(func() error { if builder.executionDataSyncEnabled { if builder.executionDataConfig.FetchTimeout <= 0 { diff --git a/engine/access/rest/websockets/config.go b/engine/access/rest/websockets/config.go index 1cb1a74c183..1f2a2ce99b7 100644 --- a/engine/access/rest/websockets/config.go +++ b/engine/access/rest/websockets/config.go @@ -29,16 +29,24 @@ const ( // if the client is slow or unresponsive. This prevents resource exhaustion // and allows the server to gracefully handle timeouts for delayed writes. WriteWait = 10 * time.Second + + // DefaultInactivityTimeout is the default duration a WebSocket connection can remain open without any active subscriptions + // before being automatically closed + DefaultInactivityTimeout time.Duration = 1 * time.Minute ) type Config struct { MaxSubscriptionsPerConnection uint64 MaxResponsesPerSecond uint64 + // InactivityTimeout specifies the duration a WebSocket connection can remain open without any active subscriptions + // before being automatically closed + InactivityTimeout time.Duration } func NewDefaultWebsocketConfig() Config { return Config{ MaxSubscriptionsPerConnection: 1000, MaxResponsesPerSecond: 1000, + InactivityTimeout: DefaultInactivityTimeout, } } diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index bffa57350c0..15c187fc650 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -159,7 +159,7 @@ func (c *Controller) HandleConnection(ctx context.Context) { err := c.configureKeepalive() if err != nil { - c.logger.Error().Err(err).Msg("error configuring connection") + c.logger.Error().Err(err).Msg("error configuring keepalive connection") return } @@ -237,8 +237,16 @@ func (c *Controller) keepalive(ctx context.Context) error { } // writeMessages reads a messages from multiplexed stream and passes them on to a client WebSocket connection. -// The multiplexed stream channel is filled by data providers +// The multiplexed stream channel is filled by data providers. +// The function tracks the last message sent and periodically checks for inactivity. +// If no messages are sent within InactivityTimeout and no active data providers exist, +// the connection will be closed. func (c *Controller) writeMessages(ctx context.Context) error { + inactivityTicker := time.NewTicker(c.config.InactivityTimeout / 10) + defer inactivityTicker.Stop() + + lastMessageSentAt := time.Now() + defer func() { // drain the channel as some providers may still send data to it after this routine shutdowns // so, in order to not run into deadlock there should be at least 1 reader on the channel @@ -257,6 +265,11 @@ func (c *Controller) writeMessages(ctx context.Context) error { return nil } + // Specifies a timeout for the write operation. If the write + // isn't completed within this duration, it fails with a timeout error. + // SetWriteDeadline ensures the write operation does not block indefinitely + // if the client is slow or unresponsive. This prevents resource exhaustion + // and allows the server to gracefully handle timeouts for delayed writes. if err := c.conn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil { return fmt.Errorf("failed to set the write deadline: %w", err) } @@ -264,6 +277,18 @@ func (c *Controller) writeMessages(ctx context.Context) error { if err := c.conn.WriteJSON(message); err != nil { return err } + + lastMessageSentAt = time.Now() + + case <-inactivityTicker.C: + hasNoActiveSubscriptions := c.dataProviders.Size() == 0 + exceedsInactivityTimeout := time.Since(lastMessageSentAt) > c.config.InactivityTimeout + if hasNoActiveSubscriptions && exceedsInactivityTimeout { + c.logger.Debug(). + Dur("timeout", c.config.InactivityTimeout). + Msg("connection inactive, closing due to timeout") + return fmt.Errorf("no recent activity for %v", c.config.InactivityTimeout) + } } } } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 9707dbb8205..1a52f79b516 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -7,13 +7,12 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" @@ -809,6 +808,37 @@ func (s *WsControllerSuite) TestControllerShutdown() { conn.AssertExpectations(t) }) + + s.T().Run("Inactivity tracking", func(t *testing.T) { + t.Parallel() + + conn := connmock.NewWebsocketConnection(t) + conn.On("Close").Return(nil).Once() + conn.On("SetReadDeadline", mock.Anything).Return(nil).Once() + conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() + + factory := dpmock.NewDataProviderFactory(t) + // Mock with short inactivity timeout for testing + wsConfig := s.wsConfig + + wsConfig.InactivityTimeout = 50 * time.Millisecond + controller := NewWebSocketController(s.logger, wsConfig, conn, factory) + + conn. + On("ReadJSON", mock.Anything). + Return(func(interface{}) error { + // waiting more than InactivityTimeout to make sure that read message routine busy and do not return + // an error before than inactivity tracker initiate shut down + <-time.After(wsConfig.InactivityTimeout) + return websocket.ErrCloseSent + }). + Once() + + controller.HandleConnection(context.Background()) + time.Sleep(wsConfig.InactivityTimeout) + + conn.AssertExpectations(t) + }) } func (s *WsControllerSuite) TestKeepaliveRoutine() {