diff --git a/README.md b/README.md index 93e60c37d..aac70ef3b 100644 --- a/README.md +++ b/README.md @@ -775,6 +775,7 @@ Setting up Splunk with a _HTTP Event Collector_ - `ignore_tag_prefix_list`: (optional) Choose which tags to be ignored by the Splunk Pump. Keep in mind that the tag name and value are hyphenated. Type: Type: String Array `[] string`. Default value is `[]` - `enable_batch`: If this is set to `true`, pump is going to send the analytics records in batch to Splunk. Type: Boolean. Default value is `false`. - `max_content_length`: Max content length in bytes to be sent in batch requests. It should match the `max_content_length` configured in Splunk. If the purged analytics records size don't reach the amount of bytes, they're send anyways in each `purge_loop`. Type: Integer. Default value is 838860800 (~ 800 MB), the same default value as Splunk config. +- `max_retries`: Max number of retries if failed to send requests to splunk HEC. Default value is `0` (no retries after failure). Connections, network, timeouts, temporary, too many requests and internal server errors are all considered retryable. ###### JSON / Conf File @@ -791,6 +792,7 @@ Setting up Splunk with a _HTTP Event Collector_ "obfuscate_api_keys": true, "obfuscate_api_keys_length": 10, "enable_batch":true, + "max_retries": 2, "fields": [ "method", "host", diff --git a/http-retry/http-retry.go b/http-retry/http-retry.go new file mode 100644 index 000000000..3a6bdd3e7 --- /dev/null +++ b/http-retry/http-retry.go @@ -0,0 +1,150 @@ +package httpretry + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/sirupsen/logrus" +) + +type BackoffHTTPRetry struct { + logger *logrus.Entry + httpclient *http.Client + errMsg string + maxRetries uint64 +} + +type ( + conError interface{ ConnectionError() bool } + tempError interface{ Temporary() bool } + timeoutError interface{ Timeout() bool } +) + +// NewBackoffRetry Creates an exponential backoff retry to use httpClient for connections. Will retry if a temporary error or +// 5xx or 429 status code in response. +func NewBackoffRetry(errMsg string, maxRetries uint64, httpClient *http.Client, logger *logrus.Entry) *BackoffHTTPRetry { + return &BackoffHTTPRetry{errMsg: errMsg, maxRetries: maxRetries, httpclient: httpClient, logger: logger} +} + +func (s *BackoffHTTPRetry) Send(req *http.Request) error { + var reqBody []byte + if req.Body != nil { + var err error + reqBody, err = io.ReadAll(req.Body) + if err != nil { + s.logger.WithError(err).Error("Failed to read req body") + return err + } + req.Body.Close() // closing the original body + } + + opFn := func() error { + // recreating the request body from the buffer for each retry as if first attempt fails and + // a new conn is created (keep alive disabled on server for example) the req body has already been read, + // resulting in "http: ContentLength=X with Body length Y" error + req.Body = io.NopCloser(bytes.NewBuffer(reqBody)) + + t := time.Now() + resp, err := s.httpclient.Do(req) + s.logger.Debugf("Req %s took %s", req.URL, time.Since(t)) + + if err != nil { + return s.handleErr(err) + } + defer func() { + // read all response and discard so http client can + // reuse connection as per doc on Response.Body + _, err := io.Copy(io.Discard, resp.Body) + if err != nil { + s.logger.WithError(err).Error("Failed to read and discard resp body") + } + resp.Body.Close() + }() + + if resp.StatusCode == http.StatusOK { + return nil + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + s.logger.WithError(err).Error("Failed to read resp body") + // attempt retry + return err + } + + err = fmt.Errorf("got status code %d and response '%s'", resp.StatusCode, body) + + // server error or rate limit hit - attempt retry + if resp.StatusCode >= http.StatusInternalServerError || resp.StatusCode == http.StatusTooManyRequests { + return err + } + + // any other error treat as permanent (i.e. auth error, invalid request) and don't retry + return backoff.Permanent(err) + } + + return backoff.RetryNotify(opFn, backoff.WithMaxRetries(backoff.NewExponentialBackOff(), s.maxRetries), func(err error, t time.Duration) { + s.logger.WithError(err).Warningf("%s retrying in %s", s.errMsg, t) + }) +} + +func (s *BackoffHTTPRetry) handleErr(err error) error { + if isErrorRetryable(err) { + return err + } + // permanent error - don't retry + return backoff.Permanent(err) +} + +func isErrorRetryable(err error) bool { + if err == nil { + return false + } + + var ( + conErr conError + tempErr tempError + timeoutErr timeoutError + urlErr *url.Error + netOpErr *net.OpError + ) + + switch { + case errors.As(err, &conErr) && conErr.ConnectionError(): + return true + case strings.Contains(err.Error(), "connection reset"): + return true + case errors.As(err, &urlErr): + // Refused connections should be retried as the service may not yet be + // running on the port. Go TCP dial considers refused connections as + // not temporary. + if strings.Contains(urlErr.Error(), "connection refused") { + return true + } + return isErrorRetryable(errors.Unwrap(urlErr)) + case errors.As(err, &netOpErr): + // Network dial, or temporary network errors are always retryable. + if strings.EqualFold(netOpErr.Op, "dial") || netOpErr.Temporary() { + return true + } + return isErrorRetryable(errors.Unwrap(netOpErr)) + case errors.As(err, &tempErr) && tempErr.Temporary(): + // Fallback to the generic temporary check, with temporary errors + // retryable. + return true + case errors.As(err, &timeoutErr) && timeoutErr.Timeout(): + // Fallback to the generic timeout check, with timeout errors + // retryable. + return true + } + + return false +} diff --git a/pumps/splunk.go b/pumps/splunk.go index 993f82c87..1c7c27218 100644 --- a/pumps/splunk.go +++ b/pumps/splunk.go @@ -10,9 +10,9 @@ import ( "net/url" "strings" - "github.com/mitchellh/mapstructure" - "github.com/TykTechnologies/tyk-pump/analytics" + retry "github.com/TykTechnologies/tyk-pump/http-retry" + "github.com/mitchellh/mapstructure" ) const ( @@ -35,8 +35,8 @@ type SplunkClient struct { Token string CollectorURL string TLSSkipVerify bool - - httpClient *http.Client + httpClient *http.Client + retry *retry.BackoffHTTPRetry } // SplunkPump is a Tyk Pump driver for Splunk. @@ -85,6 +85,8 @@ type SplunkPumpConfig struct { // the amount of bytes, they're send anyways in each `purge_loop`. Default value is 838860800 // (~ 800 MB), the same default value as Splunk config. BatchMaxContentLength int `json:"batch_max_content_length" mapstructure:"batch_max_content_length"` + // MaxRetries the maximum amount of retries if failed to send requests to splunk HEC. Default value is `0` + MaxRetries uint64 `json:"max_retries" mapstructure:"max_retries"` } // New initializes a new pump. @@ -124,6 +126,11 @@ func (p *SplunkPump) Init(config interface{}) error { p.config.BatchMaxContentLength = maxContentLength } + if p.config.MaxRetries > 0 { + p.log.Infof("%d max retries", p.config.MaxRetries) + } + + p.client.retry = retry.NewBackoffRetry("Failed writing data to Splunk", p.config.MaxRetries, p.client.httpClient, p.log) p.log.Info(p.GetName() + " Initialized") return nil @@ -153,15 +160,6 @@ func (p *SplunkPump) WriteData(ctx context.Context, data []interface{}) error { var batchBuffer bytes.Buffer - fnSendBytes := func(data []byte) error { - _, errSend := p.client.Send(ctx, data) - if errSend != nil { - p.log.Error("Error writing data to Splunk ", errSend) - return errSend - } - return nil - } - for _, v := range data { decoded := v.(analytics.AnalyticsRecord) apiKey := decoded.APIKey @@ -253,14 +251,14 @@ func (p *SplunkPump) WriteData(ctx context.Context, data []interface{}) error { if p.config.EnableBatch { //if we're batching and the len of our data is already bigger than max_content_length, we send the data and reset the buffer if batchBuffer.Len()+len(data) > p.config.BatchMaxContentLength { - if err := fnSendBytes(batchBuffer.Bytes()); err != nil { + if err := p.send(ctx, batchBuffer.Bytes()); err != nil { return err } batchBuffer.Reset() } batchBuffer.Write(data) } else { - if err := fnSendBytes(data); err != nil { + if err := p.send(ctx, data); err != nil { return err } } @@ -268,7 +266,7 @@ func (p *SplunkPump) WriteData(ctx context.Context, data []interface{}) error { //this if is for data remaining in the buffer when len(buffer) is lower than max_content_length if p.config.EnableBatch && batchBuffer.Len() > 0 { - if err := fnSendBytes(batchBuffer.Bytes()); err != nil { + if err := p.send(ctx, batchBuffer.Bytes()); err != nil { return err } batchBuffer.Reset() @@ -311,15 +309,15 @@ func NewSplunkClient(token string, collectorURL string, skipVerify bool, certFil return c, nil } -// Send sends an event to the Splunk HTTP Event Collector interface. -func (c *SplunkClient) Send(ctx context.Context, data []byte) (*http.Response, error) { - +func (p *SplunkPump) send(ctx context.Context, data []byte) error { reader := bytes.NewReader(data) - req, err := http.NewRequest("POST", c.CollectorURL, reader) + req, err := http.NewRequest(http.MethodPost, p.client.CollectorURL, reader) if err != nil { - return nil, err + return err } req = req.WithContext(ctx) - req.Header.Add(authHeaderName, authHeaderPrefix+c.Token) - return c.httpClient.Do(req) + req.Header.Add(authHeaderName, authHeaderPrefix+p.client.Token) + + p.log.Debugf("Sending %d bytes to splunk", len(data)) + return p.client.retry.Send(req) } diff --git a/pumps/splunk_test.go b/pumps/splunk_test.go index 22aeb000e..81a5ebbc7 100644 --- a/pumps/splunk_test.go +++ b/pumps/splunk_test.go @@ -26,13 +26,16 @@ type splunkStatus struct { Len int `json:"len"` } type testHandler struct { - test *testing.T - batched bool - - responses []splunkStatus + test *testing.T + batched bool + returnErrors int + responses []splunkStatus + reqCount int } func (h *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.reqCount++ + authHeaderValue := r.Header.Get("authorization") if authHeaderValue == "" { h.test.Fatal("Auth header is empty") @@ -48,6 +51,18 @@ func (h *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { h.test.Fatal("Couldn't ready body") } + r.Body.Close() + + if h.returnErrors >= h.reqCount { + fmt.Println("returning err.......") + w.WriteHeader(http.StatusInternalServerError) + _, err := w.Write([]byte("splunk internal error")) + if err != nil { + h.test.Fatalf("Failed to write response got error %v", err) + } + return + } + status := splunkStatus{Text: "Success", Code: 0} if !h.batched { event := make(map[string]interface{}) @@ -79,6 +94,116 @@ func TestSplunkInit(t *testing.T) { } } +func Test_SplunkBackoffRetry(t *testing.T) { + go t.Run("max_retries=1", func(t *testing.T) { + handler := &testHandler{test: t, batched: false, returnErrors: 1} + server := httptest.NewUnstartedServer(handler) + server.Config.SetKeepAlivesEnabled(false) + server.Start() + + defer server.Close() + + pmp := SplunkPump{} + cfg := make(map[string]interface{}) + cfg["collector_token"] = testToken + cfg["max_retries"] = 1 + cfg["collector_url"] = server.URL + cfg["ssl_insecure_skip_verify"] = true + + if err := pmp.Init(cfg); err != nil { + t.Errorf("Error initializing pump %v", err) + return + } + + keys := make([]interface{}, 1) + + keys[0] = analytics.AnalyticsRecord{OrgID: "1", APIID: "123", Path: "/test-path", Method: "POST", TimeStamp: time.Now()} + + if errWrite := pmp.WriteData(context.TODO(), keys); errWrite != nil { + t.Error("Error writing to splunk pump:", errWrite.Error()) + return + } + + assert.Equal(t, 1, len(handler.responses)) + assert.Equal(t, 2, handler.reqCount) + + response := handler.responses[0] + + assert.Equal(t, "Success", response.Text) + assert.Equal(t, int32(0), response.Code) + }) + + t.Run("max_retries=0", func(t *testing.T) { + handler := &testHandler{test: t, batched: false, returnErrors: 1} + server := httptest.NewUnstartedServer(handler) + server.Config.SetKeepAlivesEnabled(false) + server.Start() + + defer server.Close() + + pmp := SplunkPump{} + cfg := make(map[string]interface{}) + cfg["collector_token"] = testToken + cfg["max_retries"] = 0 + cfg["collector_url"] = server.URL + cfg["ssl_insecure_skip_verify"] = true + + if err := pmp.Init(cfg); err != nil { + t.Errorf("Error initializing pump %v", err) + return + } + + keys := make([]interface{}, 1) + + keys[0] = analytics.AnalyticsRecord{OrgID: "1", APIID: "123", Path: "/test-path", Method: "POST", TimeStamp: time.Now()} + + if errWrite := pmp.WriteData(context.TODO(), keys); errWrite == nil { + t.Error("Error expected writing to splunk pump, got nil") + return + } + + assert.Equal(t, 1, handler.reqCount) + }) + + t.Run("max_retries=3", func(t *testing.T) { + handler := &testHandler{test: t, batched: false, returnErrors: 2} + server := httptest.NewUnstartedServer(handler) + server.Config.SetKeepAlivesEnabled(false) + server.Start() + + defer server.Close() + + pmp := SplunkPump{} + cfg := make(map[string]interface{}) + cfg["collector_token"] = testToken + cfg["max_retries"] = 3 + cfg["collector_url"] = server.URL + cfg["ssl_insecure_skip_verify"] = true + + if err := pmp.Init(cfg); err != nil { + t.Errorf("Error initializing pump %v", err) + return + } + + keys := make([]interface{}, 1) + + keys[0] = analytics.AnalyticsRecord{OrgID: "1", APIID: "123", Path: "/test-path", Method: "POST", TimeStamp: time.Now()} + + if errWrite := pmp.WriteData(context.TODO(), keys); errWrite != nil { + t.Error("Error writing to splunk pump:", errWrite.Error()) + return + } + + assert.Equal(t, 1, len(handler.responses)) + assert.Equal(t, 3, handler.reqCount) + + response := handler.responses[0] + + assert.Equal(t, "Success", response.Text) + assert.Equal(t, int32(0), response.Code) + }) +} + func Test_SplunkWriteData(t *testing.T) { handler := &testHandler{test: t, batched: false} server := httptest.NewServer(handler) @@ -112,6 +237,7 @@ func Test_SplunkWriteData(t *testing.T) { assert.Equal(t, "Success", response.Text) assert.Equal(t, int32(0), response.Code) } + func Test_SplunkWriteDataBatch(t *testing.T) { handler := &testHandler{test: t, batched: true} server := httptest.NewServer(handler) @@ -148,7 +274,6 @@ func Test_SplunkWriteDataBatch(t *testing.T) { assert.Equal(t, getEventBytes(keys[:2]), handler.responses[0].Len) assert.Equal(t, getEventBytes(keys[2:]), handler.responses[1].Len) - } // getEventBytes returns the bytes amount of the marshalled events struct