Skip to content

Commit

Permalink
fix: add additional request headers to health checks (#123)
Browse files Browse the repository at this point in the history
* fix: add additional request headers to health checks

* fix: remove unnecessary check for additionalRequestHeaders == nil
  • Loading branch information
dianwen authored Jun 27, 2023
1 parent 5e4111b commit 404f169
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 30 deletions.
6 changes: 4 additions & 2 deletions internal/checks/blockheight.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (c *BlockHeightCheck) initializeWebsockets() error {
}

func (c *BlockHeightCheck) initializeHTTP() {
httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &client.BasicAuthCredentials{Username: c.upstreamConfig.BasicAuthConfig.Username, Password: c.upstreamConfig.BasicAuthConfig.Password})
httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &c.upstreamConfig.BasicAuthConfig, &c.upstreamConfig.RequestHeadersConfig)
if err != nil {
c.metricsContainer.BlockHeightCheckErrors.WithLabelValues(c.upstreamConfig.ID, c.upstreamConfig.HTTPURL, metrics.HTTPInit).Inc()
c.setError(err)
Expand Down Expand Up @@ -112,7 +112,9 @@ func (c *BlockHeightCheck) runCheckHTTP() {
header, err := c.httpClient.HeaderByNumber(ctx, nil)

if c.blockHeightError = err; c.blockHeightError != nil {
c.logger.Debug("BlockHeightCheck request failed.", zap.Any("upstreamID", c.upstreamConfig.ID), zap.String("httpURL", c.upstreamConfig.HTTPURL), zap.Error(c.blockHeightError))
c.metricsContainer.BlockHeightCheckErrors.WithLabelValues(c.upstreamConfig.ID, c.upstreamConfig.HTTPURL, metrics.HTTPRequest).Inc()

return
}

Expand Down Expand Up @@ -175,7 +177,7 @@ func (c *BlockHeightCheck) subscribeNewHead() error {
c.setError(c.webSocketError)
}

wsClient, err := c.clientGetter(c.upstreamConfig.WSURL, &client.BasicAuthCredentials{Username: c.upstreamConfig.BasicAuthConfig.Username, Password: c.upstreamConfig.BasicAuthConfig.Password})
wsClient, err := c.clientGetter(c.upstreamConfig.WSURL, &c.upstreamConfig.BasicAuthConfig, &c.upstreamConfig.RequestHeadersConfig)
if err != nil {
c.webSocketError = err
return err
Expand Down
6 changes: 3 additions & 3 deletions internal/checks/blockheight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestBlockHeightChecker_WS(t *testing.T) {
ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(&mockSubscription{}, nil)
ethClient.On("HeaderByNumber", mock.Anything, mock.Anything).Return(&types.Header{Number: big.NewInt(int64(maxBlockHeight))}, nil)

mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) {
mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) {
return ethClient, nil
}

Expand Down Expand Up @@ -59,7 +59,7 @@ func TestBlockHeightChecker_WSSubscribeFailed(t *testing.T) {
ethClient.On("SubscribeNewHead", mock.Anything, mock.Anything).Return(nil, errors.New("some error"))
ethClient.On("HeaderByNumber", mock.Anything, mock.Anything).Return(&types.Header{Number: big.NewInt(int64(50000))}, nil)

mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) {
mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) {
return ethClient, nil
}

Expand Down Expand Up @@ -94,7 +94,7 @@ func TestBlockHeightChecker_HTTP(t *testing.T) {
ethClient := mocks.NewEthClient(t)
ethClient.On("HeaderByNumber", mock.Anything, mock.Anything).Return(&types.Header{Number: big.NewInt(int64(maxBlockHeight))}, nil)

mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) {
mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) {
return ethClient, nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/checks/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

func TestHealthCheckManager(t *testing.T) {
ethereumClient := mocks.NewEthClient(t)
mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) {
mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) {
return ethereumClient, nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/checks/peers.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func NewPeerChecker(
func (c *PeerCheck) Initialize() error {
c.logger.Debug("Initializing PeerCheck.", zap.Any("config", c.upstreamConfig))

httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &client.BasicAuthCredentials{Username: c.upstreamConfig.BasicAuthConfig.Username, Password: c.upstreamConfig.BasicAuthConfig.Password})
httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &c.upstreamConfig.BasicAuthConfig, &c.upstreamConfig.RequestHeadersConfig)
if err != nil {
c.Err = err
return c.Err
Expand Down
6 changes: 3 additions & 3 deletions internal/checks/peers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestPeerChecker(t *testing.T) {
ethClient := mocks.NewEthClient(t)
ethClient.EXPECT().PeerCount(mock.Anything).Return(uint64(4), nil)

mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) {
mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) {
return ethClient, nil
}

Expand Down Expand Up @@ -52,7 +52,7 @@ func TestPeerChecker_MethodNotSupported(t *testing.T) {
ethClient := mocks.NewEthClient(t)
ethClient.EXPECT().PeerCount(mock.Anything).Return(uint64(0), methodNotSupportedError{})

mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) {
mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) {
return ethClient, nil
}

Expand All @@ -69,7 +69,7 @@ func TestPeerChecker_SkipPeerCountCheck(t *testing.T) {
ethClient := mocks.NewEthClient(t)
ethClient.EXPECT().PeerCount(mock.Anything).Return(uint64(0), nil)

mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) {
mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) {
return ethClient, nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/checks/syncing.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func NewSyncingChecker(upstreamConfig *conf.UpstreamConfig, clientGetter client.
func (c *SyncingCheck) Initialize() error {
c.logger.Debug("Initializing SyncingCheck.", zap.Any("config", c.upstreamConfig))

httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &client.BasicAuthCredentials{Username: c.upstreamConfig.BasicAuthConfig.Username, Password: c.upstreamConfig.BasicAuthConfig.Password})
httpClient, err := c.clientGetter(c.upstreamConfig.HTTPURL, &c.upstreamConfig.BasicAuthConfig, &c.upstreamConfig.RequestHeadersConfig)
if err != nil {
c.Err = err
return c.Err
Expand Down
4 changes: 2 additions & 2 deletions internal/checks/syncing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestSyncingChecker(t *testing.T) {
ethClient := mocks.NewEthClient(t)
ethClient.On("SyncProgress", mock.Anything).Return(&ethereum.SyncProgress{}, nil)

mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) {
mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) {
return ethClient, nil
}

Expand Down Expand Up @@ -46,7 +46,7 @@ func TestSyncingChecker_MethodNotSupported(t *testing.T) {
ethClient := mocks.NewEthClient(t)
ethClient.On("SyncProgress", mock.Anything).Return(nil, methodNotSupportedError{})

mockEthClientGetter := func(url string, credentials *client.BasicAuthCredentials) (client.EthClient, error) {
mockEthClientGetter := func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (client.EthClient, error) {
return ethClient, nil
}

Expand Down
41 changes: 24 additions & 17 deletions internal/client/ethereum_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/rpc"
"github.com/satsuma-data/node-gateway/internal/config"
)

const (
Expand All @@ -31,19 +32,25 @@ type EthClient interface {
SyncProgress(ctx context.Context) (*ethereum.SyncProgress, error)
}

type BasicAuthCredentials struct {
Username string
Password string
}
type EthClientGetter func(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (EthClient, error)

func NewEthClient(url string, credentials *config.BasicAuthConfig, additionalRequestHeaders *[]config.RequestHeaderConfig) (EthClient, error) {
rpcClient, err := getRPCClientWithAuthHeader(url, credentials)
if err != nil {
return nil, err
}

setAdditionalRequestHeaders(rpcClient, additionalRequestHeaders)

type EthClientGetter func(url string, credentials *BasicAuthCredentials) (EthClient, error)
return ethclient.NewClient(rpcClient), nil
}

func NewEthClient(url string, credentials *BasicAuthCredentials) (EthClient, error) {
func getRPCClientWithAuthHeader(url string, credentials *config.BasicAuthConfig) (*rpc.Client, error) {
if credentials == nil || (credentials.Username == "" && credentials.Password == "") {
ctx, cancel := context.WithTimeout(context.Background(), clientDialTimeout)
defer cancel()

return ethclient.DialContext(ctx, url)
return rpc.DialContext(ctx, url)
}

parsedURL, err := netUrl.Parse(url)
Expand All @@ -60,31 +67,31 @@ func NewEthClient(url string, credentials *BasicAuthCredentials) (EthClient, err
ctx, cancel := context.WithTimeout(context.Background(), clientDialTimeout)
defer cancel()

c, err := rpc.DialContext(ctx, url)
rpcClient, err := rpc.DialContext(ctx, url)

if err != nil {
return nil, err
}

encodedCredentials := base64.StdEncoding.EncodeToString([]byte(credentials.Username + ":" + credentials.Password))
c.SetHeader("Authorization", "Basic "+encodedCredentials)
rpcClient.SetHeader("Authorization", "Basic "+encodedCredentials)

return ethclient.NewClient(c), nil
return rpcClient, nil
case "ws", "wss":
parsedURL.User = netUrl.UserPassword(credentials.Username, credentials.Password)
urlWithUser := parsedURL.String()

ctx, cancel := context.WithTimeout(context.Background(), clientDialTimeout)
defer cancel()

c, err := rpc.DialContext(ctx, urlWithUser)

if err != nil {
return nil, err
}

return ethclient.NewClient(c), nil
return rpc.DialContext(ctx, urlWithUser)
default:
return nil, fmt.Errorf("unsupported scheme: %s", parsedURL.Scheme)
}
}

func setAdditionalRequestHeaders(c *rpc.Client, additionalRequestHeaders *[]config.RequestHeaderConfig) {
for _, requestHeader := range *additionalRequestHeaders {
c.SetHeader(requestHeader.Key, requestHeader.Value)
}
}

0 comments on commit 404f169

Please sign in to comment.