diff --git a/Makefile b/Makefile index 23d1627d112..2578fffe4b6 100644 --- a/Makefile +++ b/Makefile @@ -203,10 +203,10 @@ generate-mocks: install-mock-generators mockery --name 'API' --dir="./engine/protocol" --case=underscore --output="./engine/protocol/mock" --outpkg="mock" mockery --name '.*' --dir="./engine/access/state_stream" --case=underscore --output="./engine/access/state_stream/mock" --outpkg="mock" mockery --name 'BlockTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" + mockery --name 'DataProvider' --dir="./engine/access/rest/websockets/data_provider" --case=underscore --output="./engine/access/rest/websockets/data_provider/mock" --outpkg="mock" mockery --name 'ExecutionDataTracker' --dir="./engine/access/subscription" --case=underscore --output="./engine/access/subscription/mock" --outpkg="mock" mockery --name 'ConnectionFactory' --dir="./engine/access/rpc/connection" --case=underscore --output="./engine/access/rpc/connection/mock" --outpkg="mock" mockery --name 'Communicator' --dir="./engine/access/rpc/backend" --case=underscore --output="./engine/access/rpc/backend/mock" --outpkg="mock" - mockery --name '.*' --dir=model/fingerprint --case=underscore --output="./model/fingerprint/mock" --outpkg="mock" mockery --name 'ExecForkActor' --structname 'ExecForkActorMock' --dir=module/mempool/consensus/mock/ --case=underscore --output="./module/mempool/consensus/mock/" --outpkg="mock" mockery --name '.*' --dir=engine/verification/fetcher/ --case=underscore --output="./engine/verification/fetcher/mock" --outpkg="mockfetcher" diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index 21fe924ac0d..1bb6a8c04bb 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -45,6 +45,7 @@ import ( restapiproxy "github.com/onflow/flow-go/engine/access/rest/apiproxy" commonrest "github.com/onflow/flow-go/engine/access/rest/common" "github.com/onflow/flow-go/engine/access/rest/router" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" rpcConnection "github.com/onflow/flow-go/engine/access/rpc/connection" @@ -168,6 +169,7 @@ type ObserverServiceConfig struct { registerCacheSize uint programCacheSize uint registerDBPruneThreshold uint64 + websocketConfig websockets.Config } // DefaultObserverServiceConfig defines all the default values for the ObserverServiceConfig @@ -252,6 +254,7 @@ func DefaultObserverServiceConfig() *ObserverServiceConfig { registerCacheSize: 0, programCacheSize: 0, registerDBPruneThreshold: pruner.DefaultThreshold, + websocketConfig: websockets.NewDefaultWebsocketConfig(), } } diff --git a/cmd/util/cmd/run-script/cmd.go b/cmd/util/cmd/run-script/cmd.go index 1f24d2599c2..171f97e76b7 100644 --- a/cmd/util/cmd/run-script/cmd.go +++ b/cmd/util/cmd/run-script/cmd.go @@ -16,6 +16,7 @@ import ( "github.com/onflow/flow-go/cmd/util/ledger/util" "github.com/onflow/flow-go/cmd/util/ledger/util/registers" "github.com/onflow/flow-go/engine/access/rest" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/engine/access/subscription" "github.com/onflow/flow-go/engine/execution/computation" @@ -169,6 +170,7 @@ func run(*cobra.Command, []string) { metrics.NewNoopCollector(), nil, backend.Config{}, + websockets.NewDefaultWebsocketConfig(), ) if err != nil { log.Fatal().Err(err).Msg("failed to create server") diff --git a/engine/access/handle_irrecoverable_state_test.go b/engine/access/handle_irrecoverable_state_test.go index 303a542339a..456c5cd97fd 100644 --- a/engine/access/handle_irrecoverable_state_test.go +++ b/engine/access/handle_irrecoverable_state_test.go @@ -23,6 +23,7 @@ import ( accessmock "github.com/onflow/flow-go/engine/access/mock" "github.com/onflow/flow-go/engine/access/rest" "github.com/onflow/flow-go/engine/access/rest/router" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -109,6 +110,7 @@ func (suite *IrrecoverableStateTestSuite) SetupTest() { RestConfig: rest.Config{ ListenAddress: unittest.DefaultAddress, }, + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server diff --git a/engine/access/integration_unsecure_grpc_server_test.go b/engine/access/integration_unsecure_grpc_server_test.go index f99805687ba..3c4aeca97d4 100644 --- a/engine/access/integration_unsecure_grpc_server_test.go +++ b/engine/access/integration_unsecure_grpc_server_test.go @@ -21,6 +21,7 @@ import ( "github.com/onflow/flow-go/engine" "github.com/onflow/flow-go/engine/access/index" accessmock "github.com/onflow/flow-go/engine/access/mock" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" "github.com/onflow/flow-go/engine/access/state_stream" @@ -138,6 +139,7 @@ func (suite *SameGRPCPortTestSuite) SetupTest() { UnsecureGRPCListenAddr: unittest.DefaultAddress, SecureGRPCListenAddr: unittest.DefaultAddress, HTTPListenAddr: unittest.DefaultAddress, + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), } blockCount := 5 diff --git a/engine/access/rest/router/router.go b/engine/access/rest/router/router.go index 102f9797639..a2d81cb0a58 100644 --- a/engine/access/rest/router/router.go +++ b/engine/access/rest/router/router.go @@ -2,6 +2,7 @@ package router import ( "fmt" + "net/http" "regexp" "strings" @@ -10,8 +11,9 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/common/middleware" - "github.com/onflow/flow-go/engine/access/rest/http" + flowhttp "github.com/onflow/flow-go/engine/access/rest/http" "github.com/onflow/flow-go/engine/access/rest/http/models" + "github.com/onflow/flow-go/engine/access/rest/websockets" legacyws "github.com/onflow/flow-go/engine/access/rest/websockets/legacy" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -54,7 +56,7 @@ func (b *RouterBuilder) AddRestRoutes( ) *RouterBuilder { linkGenerator := models.NewLinkGeneratorImpl(b.v1SubRouter) for _, r := range Routes { - h := http.NewHandler(b.logger, backend, r.Handler, linkGenerator, chain, maxRequestSize) + h := flowhttp.NewHandler(b.logger, backend, r.Handler, linkGenerator, chain, maxRequestSize) b.v1SubRouter. Methods(r.Method). Path(r.Pattern). @@ -64,8 +66,8 @@ func (b *RouterBuilder) AddRestRoutes( return b } -// AddWsLegacyRoutes adds WebSocket routes to the router. -func (b *RouterBuilder) AddWsLegacyRoutes( +// AddLegacyWebsocketsRoutes adds WebSocket routes to the router. +func (b *RouterBuilder) AddLegacyWebsocketsRoutes( stateStreamApi state_stream.API, chain flow.Chain, stateStreamConfig backend.Config, @@ -84,6 +86,23 @@ func (b *RouterBuilder) AddWsLegacyRoutes( return b } +func (b *RouterBuilder) AddWebsocketsRoute( + chain flow.Chain, + config websockets.Config, + streamApi state_stream.API, + streamConfig backend.Config, + maxRequestSize int64, +) *RouterBuilder { + handler := websockets.NewWebSocketHandler(b.logger, config, chain, streamApi, streamConfig, maxRequestSize) + b.v1SubRouter. + Methods(http.MethodGet). + Path("/ws"). + Name("ws"). + Handler(handler) + + return b +} + func (b *RouterBuilder) Build() *mux.Router { return b.router } diff --git a/engine/access/rest/router/router_test_helpers.go b/engine/access/rest/router/router_test_helpers.go index 59a1d27ea4d..94968d978fb 100644 --- a/engine/access/rest/router/router_test_helpers.go +++ b/engine/access/rest/router/router_test_helpers.go @@ -135,7 +135,7 @@ func ExecuteRequest(req *http.Request, backend access.API) *httptest.ResponseRec return rr } -func ExecuteWsRequest(req *http.Request, stateStreamApi state_stream.API, responseRecorder *TestHijackResponseRecorder, chain flow.Chain) { +func ExecuteLegacyWsRequest(req *http.Request, stateStreamApi state_stream.API, responseRecorder *TestHijackResponseRecorder, chain flow.Chain) { restCollector := metrics.NewNoopCollector() config := backend.Config{ @@ -147,7 +147,7 @@ func ExecuteWsRequest(req *http.Request, stateStreamApi state_stream.API, respon router := NewRouterBuilder( unittest.Logger(), restCollector, - ).AddWsLegacyRoutes( + ).AddLegacyWebsocketsRoutes( stateStreamApi, chain, config, common.DefaultMaxRequestSize, ).Build() diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index d25044a60a5..0e582d0bee4 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -9,6 +9,7 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/router" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/model/flow" @@ -42,12 +43,15 @@ func NewServer(serverAPI access.API, restCollector module.RestMetrics, stateStreamApi state_stream.API, stateStreamConfig backend.Config, + wsConfig websockets.Config, ) (*http.Server, error) { builder := router.NewRouterBuilder(logger, restCollector).AddRestRoutes(serverAPI, chain, config.MaxRequestSize) if stateStreamApi != nil { - builder.AddWsLegacyRoutes(stateStreamApi, chain, stateStreamConfig, config.MaxRequestSize) + builder.AddLegacyWebsocketsRoutes(stateStreamApi, chain, stateStreamConfig, config.MaxRequestSize) } + builder.AddWebsocketsRoute(chain, wsConfig, stateStreamApi, stateStreamConfig, config.MaxRequestSize) + c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedHeaders: []string{"*"}, diff --git a/engine/access/rest/websockets/config.go b/engine/access/rest/websockets/config.go new file mode 100644 index 00000000000..7f563ba94b9 --- /dev/null +++ b/engine/access/rest/websockets/config.go @@ -0,0 +1,21 @@ +package websockets + +import ( + "time" +) + +type Config struct { + MaxSubscriptionsPerConnection uint64 + MaxResponsesPerSecond uint64 + SendMessageTimeout time.Duration + MaxRequestSize int64 +} + +func NewDefaultWebsocketConfig() Config { + return Config{ + MaxSubscriptionsPerConnection: 1000, + MaxResponsesPerSecond: 1000, + SendMessageTimeout: 10 * time.Second, + MaxRequestSize: 1024, + } +} diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go new file mode 100644 index 00000000000..fe873f5f61c --- /dev/null +++ b/engine/access/rest/websockets/controller.go @@ -0,0 +1,212 @@ +package websockets + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + + dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/utils/concurrentmap" +) + +type Controller struct { + logger zerolog.Logger + config Config + conn *websocket.Conn + communicationChannel chan interface{} + dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] + dataProvidersFactory *dp.Factory +} + +func NewWebSocketController( + logger zerolog.Logger, + config Config, + streamApi state_stream.API, + streamConfig backend.Config, + conn *websocket.Conn, +) *Controller { + return &Controller{ + logger: logger.With().Str("component", "websocket-controller").Logger(), + config: config, + conn: conn, + communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? + dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), + dataProvidersFactory: dp.NewDataProviderFactory(logger, streamApi, streamConfig), + } +} + +// HandleConnection manages the WebSocket connection, adding context and error handling. +func (c *Controller) HandleConnection(ctx context.Context) { + //TODO: configure the connection with ping-pong and deadlines + //TODO: spin up a response limit tracker routine + go c.readMessagesFromClient(ctx) + c.writeMessagesToClient(ctx) +} + +// writeMessagesToClient reads a messages from communication channel and passes them on to a client WebSocket connection. +// The communication channel is filled by data providers. Besides, the response limit tracker is involved in +// write message regulation +func (c *Controller) writeMessagesToClient(ctx context.Context) { + //TODO: can it run forever? maybe we should cancel the ctx in the reader routine + for { + select { + case <-ctx.Done(): + return + case msg := <-c.communicationChannel: + // TODO: handle 'response per second' limits + + err := c.conn.WriteJSON(msg) + if err != nil { + c.logger.Error().Err(err).Msg("error writing to connection") + } + } + } +} + +// readMessagesFromClient continuously reads messages from a client WebSocket connection, +// processes each message, and handles actions based on the message type. +func (c *Controller) readMessagesFromClient(ctx context.Context) { + defer c.shutdownConnection() + + for { + select { + case <-ctx.Done(): + c.logger.Info().Msg("context canceled, stopping read message loop") + return + default: + msg, err := c.readMessage() + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) { + return + } + c.logger.Warn().Err(err).Msg("error reading message from client") + return + } + + baseMsg, validatedMsg, err := c.parseAndValidateMessage(msg) + if err != nil { + c.logger.Debug().Err(err).Msg("error parsing and validating client message") + return + } + + if err := c.handleAction(ctx, validatedMsg); err != nil { + c.logger.Warn().Err(err).Str("action", baseMsg.Action).Msg("error handling action") + } + } + } +} + +func (c *Controller) readMessage() (json.RawMessage, error) { + var message json.RawMessage + if err := c.conn.ReadJSON(&message); err != nil { + return nil, fmt.Errorf("error reading JSON from client: %w", err) + } + return message, nil +} + +func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.BaseMessageRequest, interface{}, error) { + var baseMsg models.BaseMessageRequest + if err := json.Unmarshal(message, &baseMsg); err != nil { + return models.BaseMessageRequest{}, nil, fmt.Errorf("error unmarshalling base message: %w", err) + } + + var validatedMsg interface{} + switch baseMsg.Action { + case "subscribe": + var subscribeMsg models.SubscribeMessageRequest + if err := json.Unmarshal(message, &subscribeMsg); err != nil { + return baseMsg, nil, fmt.Errorf("error unmarshalling subscribe message: %w", err) + } + //TODO: add validation logic for `topic` field + validatedMsg = subscribeMsg + + case "unsubscribe": + var unsubscribeMsg models.UnsubscribeMessageRequest + if err := json.Unmarshal(message, &unsubscribeMsg); err != nil { + return baseMsg, nil, fmt.Errorf("error unmarshalling unsubscribe message: %w", err) + } + validatedMsg = unsubscribeMsg + + case "list_subscriptions": + var listMsg models.ListSubscriptionsMessageRequest + if err := json.Unmarshal(message, &listMsg); err != nil { + return baseMsg, nil, fmt.Errorf("error unmarshalling list subscriptions message: %w", err) + } + validatedMsg = listMsg + + default: + c.logger.Debug().Str("action", baseMsg.Action).Msg("unknown action type") + return baseMsg, nil, fmt.Errorf("unknown action type: %s", baseMsg.Action) + } + + return baseMsg, validatedMsg, nil +} + +func (c *Controller) handleAction(ctx context.Context, message interface{}) error { + switch msg := message.(type) { + case models.SubscribeMessageRequest: + c.handleSubscribe(ctx, msg) + case models.UnsubscribeMessageRequest: + c.handleUnsubscribe(ctx, msg) + case models.ListSubscriptionsMessageRequest: + c.handleListSubscriptions(ctx, msg) + default: + return fmt.Errorf("unknown message type: %T", msg) + } + return nil +} + +func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { + dp := c.dataProvidersFactory.NewDataProvider(c.communicationChannel, msg.Topic) + c.dataProviders.Add(dp.ID(), dp) + dp.Run(ctx) + + //TODO: return OK response to client + c.communicationChannel <- msg +} + +func (c *Controller) handleUnsubscribe(_ context.Context, msg models.UnsubscribeMessageRequest) { + id, err := uuid.Parse(msg.ID) + if err != nil { + c.logger.Debug().Err(err).Msg("error parsing message ID") + //TODO: return an error response to client + c.communicationChannel <- err + return + } + + dp, ok := c.dataProviders.Get(id) + if ok { + dp.Close() + c.dataProviders.Remove(id) + } +} + +func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) { + //TODO: return a response to client +} + +func (c *Controller) shutdownConnection() { + defer close(c.communicationChannel) + defer func(conn *websocket.Conn) { + if err := c.conn.Close(); err != nil { + c.logger.Error().Err(err).Msg("error closing connection") + } + }(c.conn) + + err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { + dp.Close() + return nil + }) + if err != nil { + c.logger.Error().Err(err).Msg("error closing data provider") + } + + c.dataProviders.Clear() +} diff --git a/engine/access/rest/websockets/data_provider/blocks.go b/engine/access/rest/websockets/data_provider/blocks.go new file mode 100644 index 00000000000..01b4d07d2e7 --- /dev/null +++ b/engine/access/rest/websockets/data_provider/blocks.go @@ -0,0 +1,61 @@ +package data_provider + +import ( + "context" + + "github.com/google/uuid" + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/engine/access/state_stream" +) + +type MockBlockProvider struct { + id uuid.UUID + topicChan chan<- interface{} // provider is not the one who is responsible to close this channel + topic string + logger zerolog.Logger + stopProviderFunc context.CancelFunc + streamApi state_stream.API +} + +func NewMockBlockProvider( + ch chan<- interface{}, + topic string, + logger zerolog.Logger, + streamApi state_stream.API, +) *MockBlockProvider { + return &MockBlockProvider{ + id: uuid.New(), + topicChan: ch, + topic: topic, + logger: logger.With().Str("component", "block-provider").Logger(), + stopProviderFunc: nil, + streamApi: streamApi, + } +} + +func (p *MockBlockProvider) Run(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) + p.stopProviderFunc = cancel + + for { + select { + case <-ctx.Done(): + return + case p.topicChan <- "block{height: 42}": + return + } + } +} + +func (p *MockBlockProvider) ID() uuid.UUID { + return p.id +} + +func (p *MockBlockProvider) Topic() string { + return p.topic +} + +func (p *MockBlockProvider) Close() { + p.stopProviderFunc() +} diff --git a/engine/access/rest/websockets/data_provider/factory.go b/engine/access/rest/websockets/data_provider/factory.go new file mode 100644 index 00000000000..6a2658b1b95 --- /dev/null +++ b/engine/access/rest/websockets/data_provider/factory.go @@ -0,0 +1,31 @@ +package data_provider + +import ( + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" +) + +type Factory struct { + logger zerolog.Logger + streamApi state_stream.API + streamConfig backend.Config +} + +func NewDataProviderFactory(logger zerolog.Logger, streamApi state_stream.API, streamConfig backend.Config) *Factory { + return &Factory{ + logger: logger, + streamApi: streamApi, + streamConfig: streamConfig, + } +} + +func (f *Factory) NewDataProvider(ch chan<- interface{}, topic string) DataProvider { + switch topic { + case "blocks": + return NewMockBlockProvider(ch, topic, f.logger, f.streamApi) + default: + return nil + } +} diff --git a/engine/access/rest/websockets/data_provider/mock/data_provider.go b/engine/access/rest/websockets/data_provider/mock/data_provider.go new file mode 100644 index 00000000000..4a2a22a44a0 --- /dev/null +++ b/engine/access/rest/websockets/data_provider/mock/data_provider.go @@ -0,0 +1,78 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + uuid "github.com/google/uuid" +) + +// DataProvider is an autogenerated mock type for the DataProvider type +type DataProvider struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *DataProvider) Close() { + _m.Called() +} + +// ID provides a mock function with given fields: +func (_m *DataProvider) ID() uuid.UUID { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ID") + } + + var r0 uuid.UUID + if rf, ok := ret.Get(0).(func() uuid.UUID); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(uuid.UUID) + } + } + + return r0 +} + +// Run provides a mock function with given fields: ctx +func (_m *DataProvider) Run(ctx context.Context) { + _m.Called(ctx) +} + +// Topic provides a mock function with given fields: +func (_m *DataProvider) Topic() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Topic") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// NewDataProvider creates a new instance of DataProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewDataProvider(t interface { + mock.TestingT + Cleanup(func()) +}) *DataProvider { + mock := &DataProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/engine/access/rest/websockets/data_provider/provider.go b/engine/access/rest/websockets/data_provider/provider.go new file mode 100644 index 00000000000..ce2914140ba --- /dev/null +++ b/engine/access/rest/websockets/data_provider/provider.go @@ -0,0 +1,14 @@ +package data_provider + +import ( + "context" + + "github.com/google/uuid" +) + +type DataProvider interface { + Run(ctx context.Context) + ID() uuid.UUID + Topic() string + Close() +} diff --git a/engine/access/rest/websockets/handler.go b/engine/access/rest/websockets/handler.go new file mode 100644 index 00000000000..247890c2a62 --- /dev/null +++ b/engine/access/rest/websockets/handler.go @@ -0,0 +1,70 @@ +package websockets + +import ( + "context" + "net/http" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/engine/access/rest/common" + "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/model/flow" +) + +type Handler struct { + *common.HttpHandler + + logger zerolog.Logger + websocketConfig Config + streamApi state_stream.API + streamConfig backend.Config +} + +var _ http.Handler = (*Handler)(nil) + +func NewWebSocketHandler( + logger zerolog.Logger, + config Config, + chain flow.Chain, + streamApi state_stream.API, + streamConfig backend.Config, + maxRequestSize int64, +) *Handler { + return &Handler{ + HttpHandler: common.NewHttpHandler(logger, chain, maxRequestSize), + websocketConfig: config, + logger: logger, + streamApi: streamApi, + streamConfig: streamConfig, + } +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + //TODO: change to accept topic instead of URL + logger := h.HttpHandler.Logger.With().Str("websocket_subscribe_url", r.URL.String()).Logger() + + err := h.HttpHandler.VerifyRequest(w, r) + if err != nil { + // VerifyRequest sets the response error before returning + logger.Debug().Err(err).Msg("error validating websocket request") + return + } + + upgrader := websocket.Upgrader{ + // allow all origins by default, operators can override using a proxy + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + h.HttpHandler.ErrorHandler(w, common.NewRestError(http.StatusInternalServerError, "webSocket upgrade error: ", err), logger) + return + } + + controller := NewWebSocketController(logger, h.websocketConfig, h.streamApi, h.streamConfig, conn) + controller.HandleConnection(context.TODO()) +} diff --git a/engine/access/rest/websockets/handler_test.go b/engine/access/rest/websockets/handler_test.go new file mode 100644 index 00000000000..6b9cce06572 --- /dev/null +++ b/engine/access/rest/websockets/handler_test.go @@ -0,0 +1,86 @@ +package websockets_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/onflow/flow-go/engine/access/rest/websockets" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +var ( + chainID = flow.Testnet +) + +type WsHandlerSuite struct { + suite.Suite + + logger zerolog.Logger + handler *websockets.Handler + wsConfig websockets.Config + streamApi *streammock.API + streamConfig backend.Config +} + +func (s *WsHandlerSuite) SetupTest() { + s.logger = unittest.Logger() + s.wsConfig = websockets.NewDefaultWebsocketConfig() + s.streamApi = streammock.NewAPI(s.T()) + s.streamConfig = backend.Config{} + s.handler = websockets.NewWebSocketHandler(s.logger, s.wsConfig, chainID.Chain(), s.streamApi, s.streamConfig, 1024) +} + +func TestWsHandlerSuite(t *testing.T) { + suite.Run(t, new(WsHandlerSuite)) +} + +func ClientConnection(url string) (*websocket.Conn, *http.Response, error) { + wsURL := "ws" + strings.TrimPrefix(url, "http") + return websocket.DefaultDialer.Dial(wsURL, nil) +} + +func (s *WsHandlerSuite) TestSubscribeRequest() { + s.Run("Happy path", func() { + server := httptest.NewServer(s.handler) + defer server.Close() + + conn, _, err := ClientConnection(server.URL) + defer func(conn *websocket.Conn) { + err := conn.Close() + require.NoError(s.T(), err) + }(conn) + require.NoError(s.T(), err) + + args := map[string]interface{}{ + "start_block_height": 10, + } + body := models.SubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, + Topic: "blocks", + Arguments: args, + } + bodyJSON, err := json.Marshal(body) + require.NoError(s.T(), err) + + err = conn.WriteMessage(websocket.TextMessage, bodyJSON) + require.NoError(s.T(), err) + + _, msg, err := conn.ReadMessage() + require.NoError(s.T(), err) + + actualMsg := strings.Trim(string(msg), "\n\"\\ ") + require.Equal(s.T(), "block{height: 42}", actualMsg) + }) +} diff --git a/engine/access/rest/websockets/legacy/routes/subscribe_events_test.go b/engine/access/rest/websockets/legacy/routes/subscribe_events_test.go index c4353cecae2..a423bd4622f 100644 --- a/engine/access/rest/websockets/legacy/routes/subscribe_events_test.go +++ b/engine/access/rest/websockets/legacy/routes/subscribe_events_test.go @@ -252,7 +252,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { time.Sleep(1 * time.Second) respRecorder.Close() }() - router.ExecuteWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) + router.ExecuteLegacyWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) requireResponse(s.T(), respRecorder, expectedEventsResponses) }) } @@ -264,7 +264,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), s.blocks[0].Header.Height, nil, nil, nil, 1, nil) require.NoError(s.T(), err) respRecorder := router.NewTestHijackResponseRecorder() - router.ExecuteWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) + router.ExecuteLegacyWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) requireError(s.T(), respRecorder, "can only provide either block ID or start height") }) @@ -289,7 +289,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil, 1, nil) require.NoError(s.T(), err) respRecorder := router.NewTestHijackResponseRecorder() - router.ExecuteWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) + router.ExecuteLegacyWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) requireError(s.T(), respRecorder, "stream encountered an error: subscription error") }) @@ -298,7 +298,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, []string{"foo"}, nil, nil, 1, nil) require.NoError(s.T(), err) respRecorder := router.NewTestHijackResponseRecorder() - router.ExecuteWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) + router.ExecuteLegacyWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) requireError(s.T(), respRecorder, "invalid event type format") }) @@ -323,7 +323,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil, 1, nil) require.NoError(s.T(), err) respRecorder := router.NewTestHijackResponseRecorder() - router.ExecuteWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) + router.ExecuteLegacyWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) requireError(s.T(), respRecorder, "subscription channel closed") }) } diff --git a/engine/access/rest/websockets/models/base_message.go b/engine/access/rest/websockets/models/base_message.go new file mode 100644 index 00000000000..f56d62fda8f --- /dev/null +++ b/engine/access/rest/websockets/models/base_message.go @@ -0,0 +1,13 @@ +package models + +// BaseMessageRequest represents a base structure for incoming messages. +type BaseMessageRequest struct { + Action string `json:"action"` // Action type of the request +} + +// BaseMessageResponse represents a base structure for outgoing messages. +type BaseMessageResponse struct { + Action string `json:"action,omitempty"` // Action type of the response + Success bool `json:"success"` // Indicates success or failure + ErrorMessage string `json:"error_message,omitempty"` // Error message, if any +} diff --git a/engine/access/rest/websockets/models/list_subscriptions.go b/engine/access/rest/websockets/models/list_subscriptions.go new file mode 100644 index 00000000000..26174869585 --- /dev/null +++ b/engine/access/rest/websockets/models/list_subscriptions.go @@ -0,0 +1,13 @@ +package models + +// ListSubscriptionsMessageRequest represents a request to list active subscriptions. +type ListSubscriptionsMessageRequest struct { + BaseMessageRequest +} + +// ListSubscriptionsMessageResponse is the structure used to respond to list_subscriptions requests. +// It contains a list of active subscriptions for the current WebSocket connection. +type ListSubscriptionsMessageResponse struct { + BaseMessageResponse + Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` +} diff --git a/engine/access/rest/websockets/models/subscribe.go b/engine/access/rest/websockets/models/subscribe.go new file mode 100644 index 00000000000..993bd63b811 --- /dev/null +++ b/engine/access/rest/websockets/models/subscribe.go @@ -0,0 +1,15 @@ +package models + +// SubscribeMessageRequest represents a request to subscribe to a topic. +type SubscribeMessageRequest struct { + BaseMessageRequest + Topic string `json:"topic"` // Topic to subscribe to + Arguments map[string]interface{} `json:"arguments"` // Additional arguments for subscription +} + +// SubscribeMessageResponse represents the response to a subscription request. +type SubscribeMessageResponse struct { + BaseMessageResponse + Topic string `json:"topic"` // Topic of the subscription + ID string `json:"id"` // Unique subscription ID +} diff --git a/engine/access/rest/websockets/models/subscription_entry.go b/engine/access/rest/websockets/models/subscription_entry.go new file mode 100644 index 00000000000..d3f2b352bb7 --- /dev/null +++ b/engine/access/rest/websockets/models/subscription_entry.go @@ -0,0 +1,7 @@ +package models + +// SubscriptionEntry represents an active subscription entry. +type SubscriptionEntry struct { + Topic string `json:"topic,omitempty"` // Topic of the subscription + ID string `json:"id,omitempty"` // Unique subscription ID +} diff --git a/engine/access/rest/websockets/models/unsubscribe.go b/engine/access/rest/websockets/models/unsubscribe.go new file mode 100644 index 00000000000..2024bb922e0 --- /dev/null +++ b/engine/access/rest/websockets/models/unsubscribe.go @@ -0,0 +1,13 @@ +package models + +// UnsubscribeMessageRequest represents a request to unsubscribe from a topic. +type UnsubscribeMessageRequest struct { + BaseMessageRequest + ID string `json:"id"` // Unique subscription ID +} + +// UnsubscribeMessageResponse represents the response to an unsubscription request. +type UnsubscribeMessageResponse struct { + BaseMessageResponse + ID string `json:"id"` // Unique subscription ID +} diff --git a/engine/access/rest_api_test.go b/engine/access/rest_api_test.go index 64dab073c1d..651adb41a63 100644 --- a/engine/access/rest_api_test.go +++ b/engine/access/rest_api_test.go @@ -24,6 +24,7 @@ import ( "github.com/onflow/flow-go/engine/access/rest/common" "github.com/onflow/flow-go/engine/access/rest/http/request" "github.com/onflow/flow-go/engine/access/rest/router" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -137,6 +138,7 @@ func (suite *RestAPITestSuite) SetupTest() { RestConfig: rest.Config{ ListenAddress: unittest.DefaultAddress, }, + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server diff --git a/engine/access/rpc/engine.go b/engine/access/rpc/engine.go index 145e3d62143..37b60b1a4d3 100644 --- a/engine/access/rpc/engine.go +++ b/engine/access/rpc/engine.go @@ -14,6 +14,7 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/consensus/hotstuff/model" "github.com/onflow/flow-go/engine/access/rest" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc/backend" "github.com/onflow/flow-go/engine/access/state_stream" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -38,10 +39,11 @@ type Config struct { CollectionAddr string // the address of the upstream collection node HistoricalAccessAddrs string // the list of all access nodes from previous spork - BackendConfig backend.Config // configurable options for creating Backend - RestConfig rest.Config // the REST server configuration - MaxMsgSize uint // GRPC max message size - CompressorName string // GRPC compressor name + BackendConfig backend.Config // configurable options for creating Backend + RestConfig rest.Config // the REST server configuration + MaxMsgSize uint // GRPC max message size + CompressorName string // GRPC compressor name + WebSocketConfig websockets.Config } // Engine exposes the server with a simplified version of the Access API. @@ -75,7 +77,8 @@ type Engine struct { type Option func(*RPCEngineBuilder) // NewBuilder returns a new RPC engine builder. -func NewBuilder(log zerolog.Logger, +func NewBuilder( + log zerolog.Logger, state protocol.State, config Config, chainID flow.ChainID, @@ -240,8 +243,16 @@ func (e *Engine) serveREST(ctx irrecoverable.SignalerContext, ready component.Re e.log.Info().Str("rest_api_address", e.config.RestConfig.ListenAddress).Msg("starting REST server on address") - r, err := rest.NewServer(e.restHandler, e.config.RestConfig, e.log, e.chain, e.restCollector, e.stateStreamBackend, - e.stateStreamConfig) + r, err := rest.NewServer( + e.restHandler, + e.config.RestConfig, + e.log, + e.chain, + e.restCollector, + e.stateStreamBackend, + e.stateStreamConfig, + e.config.WebSocketConfig, + ) if err != nil { e.log.Err(err).Msg("failed to initialize the REST server") ctx.Throw(err) diff --git a/engine/access/rpc/rate_limit_test.go b/engine/access/rpc/rate_limit_test.go index 622b06e3f54..7148cdfefad 100644 --- a/engine/access/rpc/rate_limit_test.go +++ b/engine/access/rpc/rate_limit_test.go @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc/status" accessmock "github.com/onflow/flow-go/engine/access/mock" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc/backend" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/model/flow" @@ -115,6 +116,7 @@ func (suite *RateLimitTestSuite) SetupTest() { UnsecureGRPCListenAddr: unittest.DefaultAddress, SecureGRPCListenAddr: unittest.DefaultAddress, HTTPListenAddr: unittest.DefaultAddress, + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server diff --git a/engine/access/secure_grpcr_test.go b/engine/access/secure_grpcr_test.go index cc1d1a75cc8..6ffa8f8d324 100644 --- a/engine/access/secure_grpcr_test.go +++ b/engine/access/secure_grpcr_test.go @@ -19,6 +19,7 @@ import ( "github.com/onflow/crypto" accessmock "github.com/onflow/flow-go/engine/access/mock" + "github.com/onflow/flow-go/engine/access/rest/websockets" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -110,6 +111,7 @@ func (suite *SecureGRPCTestSuite) SetupTest() { UnsecureGRPCListenAddr: unittest.DefaultAddress, SecureGRPCListenAddr: unittest.DefaultAddress, HTTPListenAddr: unittest.DefaultAddress, + WebSocketConfig: websockets.NewDefaultWebsocketConfig(), } // generate a server certificate that will be served by the GRPC server diff --git a/engine/common/worker/worker_builder_test.go b/engine/common/worker/worker_builder_test.go index c08da0769c3..09aebe1cc41 100644 --- a/engine/common/worker/worker_builder_test.go +++ b/engine/common/worker/worker_builder_test.go @@ -14,6 +14,7 @@ import ( "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/mempool/queue" "github.com/onflow/flow-go/module/metrics" + "github.com/onflow/flow-go/utils/concurrentmap" "github.com/onflow/flow-go/utils/unittest" ) @@ -115,7 +116,7 @@ func TestWorkerPool_TwoWorkers_ConcurrentEvents(t *testing.T) { } q := queue.NewHeroStore(uint32(size), unittest.Logger(), metrics.NewNoopCollector()) - distributedEvents := unittest.NewProtectedMap[string, struct{}]() + distributedEvents := concurrentmap.New[string, struct{}]() allEventsDistributed := sync.WaitGroup{} allEventsDistributed.Add(size) diff --git a/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go b/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go index c43b7435f55..8debc74e7d7 100644 --- a/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go +++ b/insecure/integration/functional/test/gossipsub/scoring/ihave_spam_test.go @@ -19,6 +19,7 @@ import ( "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/p2p" p2ptest "github.com/onflow/flow-go/network/p2p/test" + "github.com/onflow/flow-go/utils/concurrentmap" "github.com/onflow/flow-go/utils/unittest" ) @@ -36,7 +37,7 @@ func TestGossipSubIHaveBrokenPromises_Below_Threshold(t *testing.T) { sporkId := unittest.IdentifierFixture() blockTopic := channels.TopicFromChannel(channels.PushBlocks, sporkId) - receivedIWants := unittest.NewProtectedMap[string, struct{}]() + receivedIWants := concurrentmap.New[string, struct{}]() idProvider := unittest.NewUpdatableIDProvider(flow.IdentityList{}) spammer := corruptlibp2p.NewGossipSubRouterSpammerWithRpcInspector(t, sporkId, role, idProvider, func(id peer.ID, rpc *corrupt.RPC) error { // override rpc inspector of the spammer node to keep track of the iwants it has received. @@ -188,7 +189,7 @@ func TestGossipSubIHaveBrokenPromises_Above_Threshold(t *testing.T) { sporkId := unittest.IdentifierFixture() blockTopic := channels.TopicFromChannel(channels.PushBlocks, sporkId) - receivedIWants := unittest.NewProtectedMap[string, struct{}]() + receivedIWants := concurrentmap.New[string, struct{}]() idProvider := unittest.NewUpdatableIDProvider(flow.IdentityList{}) spammer := corruptlibp2p.NewGossipSubRouterSpammerWithRpcInspector(t, sporkId, role, idProvider, func(id peer.ID, rpc *corrupt.RPC) error { // override rpc inspector of the spammer node to keep track of the iwants it has received. @@ -437,7 +438,7 @@ func TestGossipSubIHaveBrokenPromises_Above_Threshold(t *testing.T) { func spamIHaveBrokenPromise(t *testing.T, spammer *corruptlibp2p.GossipSubRouterSpammer, topic string, - receivedIWants *unittest.ProtectedMap[string, struct{}], + receivedIWants *concurrentmap.Map[string, struct{}], victimNode p2p.LibP2PNode) { rpcCount := 10 // we can't send more than one iHave per RPC in this test, as each iHave should have a distinct topic, and we only have one subscribed topic. diff --git a/network/p2p/connection/connection_gater_test.go b/network/p2p/connection/connection_gater_test.go index ed8777d3f90..e84bfe0042f 100644 --- a/network/p2p/connection/connection_gater_test.go +++ b/network/p2p/connection/connection_gater_test.go @@ -24,6 +24,7 @@ import ( mockp2p "github.com/onflow/flow-go/network/p2p/mock" p2ptest "github.com/onflow/flow-go/network/p2p/test" "github.com/onflow/flow-go/network/p2p/unicast/stream" + "github.com/onflow/flow-go/utils/concurrentmap" "github.com/onflow/flow-go/utils/unittest" ) @@ -35,7 +36,7 @@ func TestConnectionGating(t *testing.T) { sporkID := unittest.IdentifierFixture() idProvider := mockmodule.NewIdentityProvider(t) // create 2 nodes - node1Peers := unittest.NewProtectedMap[peer.ID, struct{}]() + node1Peers := concurrentmap.New[peer.ID, struct{}]() node1, node1Id := p2ptest.NodeFixture( t, sporkID, @@ -49,7 +50,7 @@ func TestConnectionGating(t *testing.T) { }))) idProvider.On("ByPeerID", node1.ID()).Return(&node1Id, true).Maybe() - node2Peers := unittest.NewProtectedMap[peer.ID, struct{}]() + node2Peers := concurrentmap.New[peer.ID, struct{}]() node2, node2Id := p2ptest.NodeFixture( t, sporkID, @@ -246,7 +247,7 @@ func TestConnectionGater_InterceptUpgrade(t *testing.T) { inbounds := make([]chan string, 0, count) identities := make(flow.IdentityList, 0, count) - disallowedPeerIds := unittest.NewProtectedMap[peer.ID, struct{}]() + disallowedPeerIds := concurrentmap.New[peer.ID, struct{}]() allPeerIds := make(peer.IDSlice, 0, count) idProvider := mockmodule.NewIdentityProvider(t) connectionGater := mockp2p.NewConnectionGater(t) @@ -331,7 +332,7 @@ func TestConnectionGater_Disallow_Integration(t *testing.T) { ids := flow.IdentityList{} inbounds := make([]chan string, 0, 5) - disallowedList := unittest.NewProtectedMap[*flow.Identity, struct{}]() + disallowedList := concurrentmap.New[*flow.Identity, struct{}]() for i := 0; i < count; i++ { handler, inbound := p2ptest.StreamHandlerFixture(t) diff --git a/network/p2p/node/libp2pNode_test.go b/network/p2p/node/libp2pNode_test.go index 9a538bd269b..d53fabb0e17 100644 --- a/network/p2p/node/libp2pNode_test.go +++ b/network/p2p/node/libp2pNode_test.go @@ -24,6 +24,7 @@ import ( p2ptest "github.com/onflow/flow-go/network/p2p/test" "github.com/onflow/flow-go/network/p2p/utils" validator "github.com/onflow/flow-go/network/validator/pubsub" + "github.com/onflow/flow-go/utils/concurrentmap" "github.com/onflow/flow-go/utils/unittest" ) @@ -158,7 +159,7 @@ func TestConnGater(t *testing.T) { sporkID := unittest.IdentifierFixture() idProvider := mockmodule.NewIdentityProvider(t) - node1Peers := unittest.NewProtectedMap[peer.ID, struct{}]() + node1Peers := concurrentmap.New[peer.ID, struct{}]() node1, identity1 := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithConnectionGater(p2ptest.NewConnectionGater(idProvider, func(pid peer.ID) error { if !node1Peers.Has(pid) { return fmt.Errorf("peer id not found: %s", p2plogging.PeerId(pid)) @@ -173,7 +174,7 @@ func TestConnGater(t *testing.T) { node1Info, err := utils.PeerAddressInfo(identity1.IdentitySkeleton) assert.NoError(t, err) - node2Peers := unittest.NewProtectedMap[peer.ID, struct{}]() + node2Peers := concurrentmap.New[peer.ID, struct{}]() node2, identity2 := p2ptest.NodeFixture(t, sporkID, t.Name(), idProvider, p2ptest.WithConnectionGater(p2ptest.NewConnectionGater(idProvider, func(pid peer.ID) error { if !node2Peers.Has(pid) { return fmt.Errorf("id not found: %s", p2plogging.PeerId(pid)) diff --git a/network/test/cohort1/network_test.go b/network/test/cohort1/network_test.go index bffd3ac52b7..723df438960 100644 --- a/network/test/cohort1/network_test.go +++ b/network/test/cohort1/network_test.go @@ -40,6 +40,7 @@ import ( "github.com/onflow/flow-go/network/p2p/unicast/ratelimit" "github.com/onflow/flow-go/network/p2p/utils/ratelimiter" "github.com/onflow/flow-go/network/underlay" + "github.com/onflow/flow-go/utils/concurrentmap" "github.com/onflow/flow-go/utils/unittest" ) @@ -617,7 +618,7 @@ func (suite *NetworkTestSuite) MultiPing(count int) { senderNodeIndex := 0 targetNodeIndex := suite.size - 1 - receivedPayloads := unittest.NewProtectedMap[string, struct{}]() // keep track of unique payloads received. + receivedPayloads := concurrentmap.New[string, struct{}]() // keep track of unique payloads received. // regex to extract the payload from the message regex := regexp.MustCompile(`^hello from: \d`) diff --git a/utils/unittest/protected_map.go b/utils/concurrentmap/concurrent_map.go similarity index 62% rename from utils/unittest/protected_map.go rename to utils/concurrentmap/concurrent_map.go index a2af2f5f513..148c3741428 100644 --- a/utils/unittest/protected_map.go +++ b/utils/concurrentmap/concurrent_map.go @@ -1,36 +1,36 @@ -package unittest +package concurrentmap import "sync" -// ProtectedMap is a thread-safe map. -type ProtectedMap[K comparable, V any] struct { +// Map is a thread-safe map. +type Map[K comparable, V any] struct { mu sync.RWMutex m map[K]V } -// NewProtectedMap returns a new ProtectedMap with the given types -func NewProtectedMap[K comparable, V any]() *ProtectedMap[K, V] { - return &ProtectedMap[K, V]{ +// New returns a new Map with the given types +func New[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{ m: make(map[K]V), } } // Add adds a key-value pair to the map -func (p *ProtectedMap[K, V]) Add(key K, value V) { +func (p *Map[K, V]) Add(key K, value V) { p.mu.Lock() defer p.mu.Unlock() p.m[key] = value } // Remove removes a key-value pair from the map -func (p *ProtectedMap[K, V]) Remove(key K) { +func (p *Map[K, V]) Remove(key K) { p.mu.Lock() defer p.mu.Unlock() delete(p.m, key) } // Has returns true if the map contains the given key -func (p *ProtectedMap[K, V]) Has(key K) bool { +func (p *Map[K, V]) Has(key K) bool { p.mu.RLock() defer p.mu.RUnlock() _, ok := p.m[key] @@ -38,7 +38,7 @@ func (p *ProtectedMap[K, V]) Has(key K) bool { } // Get returns the value for the given key and a boolean indicating if the key was found -func (p *ProtectedMap[K, V]) Get(key K) (V, bool) { +func (p *Map[K, V]) Get(key K) (V, bool) { p.mu.RLock() defer p.mu.RUnlock() value, ok := p.m[key] @@ -47,7 +47,7 @@ func (p *ProtectedMap[K, V]) Get(key K) (V, bool) { // ForEach iterates over the map and calls the given function for each key-value pair. // If the function returns an error, the iteration is stopped and the error is returned. -func (p *ProtectedMap[K, V]) ForEach(fn func(k K, v V) error) error { +func (p *Map[K, V]) ForEach(fn func(k K, v V) error) error { p.mu.RLock() defer p.mu.RUnlock() for k, v := range p.m { @@ -59,8 +59,14 @@ func (p *ProtectedMap[K, V]) ForEach(fn func(k K, v V) error) error { } // Size returns the size of the map. -func (p *ProtectedMap[K, V]) Size() int { +func (p *Map[K, V]) Size() int { p.mu.RLock() defer p.mu.RUnlock() return len(p.m) } + +func (p *Map[K, V]) Clear() { + p.mu.Lock() + defer p.mu.Unlock() + clear(p.m) +}