From 404f169233b72f20462dd18d0dee7ee07247db6e Mon Sep 17 00:00:00 2001 From: Dan Li Date: Mon, 26 Jun 2023 17:24:33 -0700 Subject: [PATCH] fix: add additional request headers to health checks (#123) * fix: add additional request headers to health checks * fix: remove unnecessary check for additionalRequestHeaders == nil --- internal/checks/blockheight.go | 6 +++-- internal/checks/blockheight_test.go | 6 ++--- internal/checks/manager_test.go | 2 +- internal/checks/peers.go | 2 +- internal/checks/peers_test.go | 6 ++--- internal/checks/syncing.go | 2 +- internal/checks/syncing_test.go | 4 +-- internal/client/ethereum_client.go | 41 +++++++++++++++++------------ 8 files changed, 39 insertions(+), 30 deletions(-) diff --git a/internal/checks/blockheight.go b/internal/checks/blockheight.go index 460fc742..ff153aac 100644 --- a/internal/checks/blockheight.go +++ b/internal/checks/blockheight.go @@ -77,7 +77,7 @@ func (c *BlockHeightCheck) initializeWebsockets() error { } func (c *BlockHeightCheck) initializeHTTP() { - httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &client.BasicAuthCredentials{Username: c.upstreamConfig.BasicAuthConfig.Username, Password: c.upstreamConfig.BasicAuthConfig.Password}) + httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &c.upstreamConfig.BasicAuthConfig, &c.upstreamConfig.RequestHeadersConfig) if err != nil { c.metricsContainer.BlockHeightCheckErrors.WithLabelValues(c.upstreamConfig.ID, c.upstreamConfig.HTTPURL, metrics.HTTPInit).Inc() c.setError(err) @@ -112,7 +112,9 @@ func (c *BlockHeightCheck) runCheckHTTP() { header, err := c.httpClient.HeaderByNumber(ctx, nil) if c.blockHeightError = err; c.blockHeightError != nil { + c.logger.Debug("BlockHeightCheck request failed.", zap.Any("upstreamID", c.upstreamConfig.ID), zap.String("httpURL", c.upstreamConfig.HTTPURL), zap.Error(c.blockHeightError)) c.metricsContainer.BlockHeightCheckErrors.WithLabelValues(c.upstreamConfig.ID, c.upstreamConfig.HTTPURL, metrics.HTTPRequest).Inc() + return } @@ -175,7 +177,7 @@ func (c *BlockHeightCheck) subscribeNewHead() error { c.setError(c.webSocketError) } - wsClient, err := c.clientGetter(c.upstreamConfig.WSURL, &client.BasicAuthCredentials{Username: c.upstreamConfig.BasicAuthConfig.Username, Password: c.upstreamConfig.BasicAuthConfig.Password}) + wsClient, err := c.clientGetter(c.upstreamConfig.WSURL, &c.upstreamConfig.BasicAuthConfig, &c.upstreamConfig.RequestHeadersConfig) if err != nil { c.webSocketError = err return err diff --git a/internal/checks/blockheight_test.go b/internal/checks/blockheight_test.go index e4c72dc4..6a4356f3 100644 --- a/internal/checks/blockheight_test.go +++ b/internal/checks/blockheight_test.go @@ -31,7 +31,7 @@ func TestBlockHeightChecker_WS(t *testing.T) { ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(&mockSubscription{}, nil) ethClient.On("HeaderByNumber", mock.Anything, mock.Anything).Return(&types.Header{Number: big.NewInt(int64(maxBlockHeight))}, nil) - mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) { + mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) { return ethClient, nil } @@ -59,7 +59,7 @@ func TestBlockHeightChecker_WSSubscribeFailed(t *testing.T) { ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, errors.New("some error")) ethClient.On("HeaderByNumber", mock.Anything, mock.Anything).Return(&types.Header{Number: big.NewInt(int64(50000))}, nil) - mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) { + mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) { return ethClient, nil } @@ -94,7 +94,7 @@ func TestBlockHeightChecker_HTTP(t *testing.T) { ethClient := mocks.NewEthClient(t) ethClient.On("HeaderByNumber", mock.Anything, mock.Anything).Return(&types.Header{Number: big.NewInt(int64(maxBlockHeight))}, nil) - mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) { + mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) { return ethClient, nil } diff --git a/internal/checks/manager_test.go b/internal/checks/manager_test.go index 1aca74b7..acb539ac 100644 --- a/internal/checks/manager_test.go +++ b/internal/checks/manager_test.go @@ -16,7 +16,7 @@ import ( func TestHealthCheckManager(t *testing.T) { ethereumClient := mocks.NewEthClient(t) - mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) { + mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) { return ethereumClient, nil } diff --git a/internal/checks/peers.go b/internal/checks/peers.go index 075e5931..ec36c79b 100644 --- a/internal/checks/peers.go +++ b/internal/checks/peers.go @@ -47,7 +47,7 @@ func NewPeerChecker( func (c *PeerCheck) Initialize() error { c.logger.Debug("Initializing PeerCheck.", zap.Any("config", c.upstreamConfig)) - httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &client.BasicAuthCredentials{Username: c.upstreamConfig.BasicAuthConfig.Username, Password: c.upstreamConfig.BasicAuthConfig.Password}) + httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &c.upstreamConfig.BasicAuthConfig, &c.upstreamConfig.RequestHeadersConfig) if err != nil { c.Err = err return c.Err diff --git a/internal/checks/peers_test.go b/internal/checks/peers_test.go index f6684c95..2461a8a2 100644 --- a/internal/checks/peers_test.go +++ b/internal/checks/peers_test.go @@ -17,7 +17,7 @@ func TestPeerChecker(t *testing.T) { ethClient := mocks.NewEthClient(t) ethClient.EXPECT().PeerCount(mock.Anything).Return(uint64(4), nil) - mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) { + mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) { return ethClient, nil } @@ -52,7 +52,7 @@ func TestPeerChecker_MethodNotSupported(t *testing.T) { ethClient := mocks.NewEthClient(t) ethClient.EXPECT().PeerCount(mock.Anything).Return(uint64(0), methodNotSupportedError{}) - mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) { + mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) { return ethClient, nil } @@ -69,7 +69,7 @@ func TestPeerChecker_SkipPeerCountCheck(t *testing.T) { ethClient := mocks.NewEthClient(t) ethClient.EXPECT().PeerCount(mock.Anything).Return(uint64(0), nil) - mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) { + mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) { return ethClient, nil } diff --git a/internal/checks/syncing.go b/internal/checks/syncing.go index dd4605f7..1e40ff67 100644 --- a/internal/checks/syncing.go +++ b/internal/checks/syncing.go @@ -44,7 +44,7 @@ func NewSyncingChecker(upstreamConfig *conf.UpstreamConfig, clientGetter client. func (c *SyncingCheck) Initialize() error { c.logger.Debug("Initializing SyncingCheck.", zap.Any("config", c.upstreamConfig)) - httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &client.BasicAuthCredentials{Username: c.upstreamConfig.BasicAuthConfig.Username, Password: c.upstreamConfig.BasicAuthConfig.Password}) + httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &c.upstreamConfig.BasicAuthConfig, &c.upstreamConfig.RequestHeadersConfig) if err != nil { c.Err = err return c.Err diff --git a/internal/checks/syncing_test.go b/internal/checks/syncing_test.go index a05b7206..97cf4370 100644 --- a/internal/checks/syncing_test.go +++ b/internal/checks/syncing_test.go @@ -18,7 +18,7 @@ func TestSyncingChecker(t *testing.T) { ethClient := mocks.NewEthClient(t) ethClient.On("SyncProgress", mock.Anything).Return(ðereum.SyncProgress{}, nil) - mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) { + mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) { return ethClient, nil } @@ -46,7 +46,7 @@ func TestSyncingChecker_MethodNotSupported(t *testing.T) { ethClient := mocks.NewEthClient(t) ethClient.On("SyncProgress", mock.Anything).Return(nil, methodNotSupportedError{}) - mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) { + mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) { return ethClient, nil } diff --git a/internal/client/ethereum_client.go b/internal/client/ethereum_client.go index 82415fcd..be5d0415 100644 --- a/internal/client/ethereum_client.go +++ b/internal/client/ethereum_client.go @@ -12,6 +12,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/rpc" + "github.com/satsuma-data/node-gateway/internal/config" ) const ( @@ -31,19 +32,25 @@ type EthClient interface { SyncProgress(ctx context.Context) (*ethereum.SyncProgress, error) } -type BasicAuthCredentials struct { - Username string - Password string -} +type EthClientGetter func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (EthClient, error) + +func NewEthClient(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (EthClient, error) { + rpcClient, err := getRPCClientWithAuthHeader(url, credentials) + if err != nil { + return nil, err + } + + setAdditionalRequestHeaders(rpcClient, additionalRequestHeaders) -type EthClientGetter func(url string, credentials *BasicAuthCredentials) (EthClient, error) + return ethclient.NewClient(rpcClient), nil +} -func NewEthClient(url string, credentials *BasicAuthCredentials) (EthClient, error) { +func getRPCClientWithAuthHeader(url string, credentials *config.BasicAuthConfig) (*rpc.Client, error) { if credentials == nil || (credentials.Username == "" && credentials.Password == "") { ctx, cancel := context.WithTimeout(context.Background(), clientDialTimeout) defer cancel() - return ethclient.DialContext(ctx, url) + return rpc.DialContext(ctx, url) } parsedURL, err := netUrl.Parse(url) @@ -60,16 +67,16 @@ func NewEthClient(url string, credentials *BasicAuthCredentials) (EthClient, err ctx, cancel := context.WithTimeout(context.Background(), clientDialTimeout) defer cancel() - c, err := rpc.DialContext(ctx, url) + rpcClient, err := rpc.DialContext(ctx, url) if err != nil { return nil, err } encodedCredentials := base64.StdEncoding.EncodeToString([]byte(credentials.Username + ":" + credentials.Password)) - c.SetHeader("Authorization", "Basic "+encodedCredentials) + rpcClient.SetHeader("Authorization", "Basic "+encodedCredentials) - return ethclient.NewClient(c), nil + return rpcClient, nil case "ws", "wss": parsedURL.User = netUrl.UserPassword(credentials.Username, credentials.Password) urlWithUser := parsedURL.String() @@ -77,14 +84,14 @@ func NewEthClient(url string, credentials *BasicAuthCredentials) (EthClient, err ctx, cancel := context.WithTimeout(context.Background(), clientDialTimeout) defer cancel() - c, err := rpc.DialContext(ctx, urlWithUser) - - if err != nil { - return nil, err - } - - return ethclient.NewClient(c), nil + return rpc.DialContext(ctx, urlWithUser) default: return nil, fmt.Errorf("unsupported scheme: %s", parsedURL.Scheme) } } + +func setAdditionalRequestHeaders(c *rpc.Client, additionalRequestHeaders *[]config.RequestHeaderConfig) { + for _, requestHeader := range *additionalRequestHeaders { + c.SetHeader(requestHeader.Key, requestHeader.Value) + } +}