diff --git a/tools/lambda-promtail/lambda-promtail/main.go b/tools/lambda-promtail/lambda-promtail/main.go index 0e0df1e880041..08143c97ac484 100644 --- a/tools/lambda-promtail/lambda-promtail/main.go +++ b/tools/lambda-promtail/lambda-promtail/main.go @@ -7,6 +7,7 @@ import ( "fmt" "net/url" "os" + "regexp" "strconv" "strings" @@ -25,19 +26,23 @@ const ( maxErrMsgLen = 1024 - invalidExtraLabelsError = "invalid value for environment variable EXTRA_LABELS. Expected a comma separated list with an even number of entries. " + invalidExtraLabelsError = "invalid value for environment variable EXTRA_LABELS. Expected a comma separated list with an even number of entries. " + invalidEvenExtraHeadersError = "invalid value for environment variable EXTRA_HTTP_HEADERS. Expected a comma separated list with an even number of entries." + invalidRegexExtraHeadersError = "invalid value for environment variable EXTRA_HTTP_HEADERS. Header key must conform with regex" ) var ( - writeAddress *url.URL - username, password, extraLabelsRaw, dropLabelsRaw, tenantID, bearerToken string - keepStream bool - batchSize int - s3Clients map[string]*s3.Client - extraLabels model.LabelSet - dropLabels []model.LabelName - skipTlsVerify bool - printLogLine bool + writeAddress *url.URL + username, password, extraLabelsRaw, dropLabelsRaw, tenantID, bearerToken, extraHeadersRaw string + keepStream bool + batchSize int + s3Clients map[string]*s3.Client + extraLabels model.LabelSet + dropLabels []model.LabelName + skipTlsVerify bool + printLogLine bool + extraHeaders map[string]string + httpHeaderKeyRegex = regexp.MustCompile("^[-A-Za-z0-9]+$") ) func setupArguments() { @@ -106,6 +111,12 @@ func setupArguments() { printLogLine = false } s3Clients = make(map[string]*s3.Client) + + extraHeadersRaw = os.Getenv("EXTRA_HTTP_HEADERS") + extraHeaders, err = parseExtraHeaders(extraHeadersRaw) + if err != nil { + panic(err) + } } func parseExtraLabels(extraLabelsRaw string, omitPrefix bool) (model.LabelSet, error) { @@ -190,6 +201,28 @@ func checkEventType(ev map[string]interface{}) (interface{}, error) { return nil, fmt.Errorf("unknown event type!") } +func parseExtraHeaders(extraHeadersRaw string) (map[string]string, error) { + extractedHeaders := make(map[string]string) + extraHeadersSplit := strings.Split(extraHeadersRaw, ",") + + if len(extraHeadersRaw) < 1 { + return extractedHeaders, nil + } + + if len(extraHeadersSplit)%2 != 0 { + return nil, fmt.Errorf(invalidEvenExtraHeadersError) + } + + for i := 0; i < len(extraHeadersSplit); i += 2 { + if !httpHeaderKeyRegex.MatchString(extraHeadersSplit[i]) { + return nil, fmt.Errorf("%s %s: %s", invalidRegexExtraHeadersError, httpHeaderKeyRegex.String(), extraHeadersSplit[i]) + } + extractedHeaders[extraHeadersSplit[i]] = extraHeadersSplit[i+1] + } + + return extractedHeaders, nil +} + func handler(ctx context.Context, ev map[string]interface{}) error { lvl, ok := os.LookupEnv("LOG_LEVEL") if !ok { diff --git a/tools/lambda-promtail/lambda-promtail/main_test.go b/tools/lambda-promtail/lambda-promtail/main_test.go index 5f70044e2b5ea..c9051550d2e15 100644 --- a/tools/lambda-promtail/lambda-promtail/main_test.go +++ b/tools/lambda-promtail/lambda-promtail/main_test.go @@ -64,3 +64,33 @@ func TestLambdaPromtail_TestDropLabels(t *testing.T) { require.NotContains(t, modifiedLabels, model.LabelName("A1")) require.Contains(t, modifiedLabels, model.LabelName("B2")) } + +func TestLambdaPromtail_ExtraHeadersValid(t *testing.T) { + extraHeaders, err := parseExtraHeaders("X-Custom-Header,This!sATota\\yCu$t0mHe4der,My-Server_WantsThis,What_ever could go here?,Expected4Entry,yLKc+QSB5VF/Gp3VPN7oOxa98yxWMxeHOAo+CW6trow=") + require.Nil(t, err) + require.Len(t, extraHeaders, 3) + require.Equal(t, extraHeaders["X-Custom-Header"], "This!sATota\\yCu$t0mHe4der") + require.Equal(t, extraHeaders["My-Server_WantsThis"], "What_ever could go here?") + require.Equal(t, extraHeaders["Expected4Entry"], "yLKc+QSB5VF/Gp3VPN7oOxa98yxWMxeHOAo+CW6trow=") +} + +func TestLambdaPromtail_ExtraHeadersInvalidHeaderKey(t *testing.T) { + extraHeaders, err := parseExtraHeaders("Th.s_Shou|d-Fa!l,a") + require.Nil(t, extraHeaders) + require.ErrorContains(t, err, "HTTP header key is invalid:") + extraHeaders, err = parseExtraHeaders("Also Not Valid ,b") + require.Nil(t, extraHeaders) + require.ErrorContains(t, err, "HTTP header key is invalid:") +} + +func TestLambdaPromtail_ExtraHeadersMissingValue(t *testing.T) { + extraHeaders, err := parseExtraHeaders("A,a,B,b,C,c,D") + require.Nil(t, extraHeaders) + require.Errorf(t, err, invalidExtraHeadersError) +} + +func TestLambdaPromtail_TestParseHeadersNoneProvided(t *testing.T) { + extraLabels, err := parseExtraHeaders("") + require.Len(t, extraLabels, 0) + require.Nil(t, err) +} diff --git a/tools/lambda-promtail/lambda-promtail/promtail.go b/tools/lambda-promtail/lambda-promtail/promtail.go index c1d01c36b174b..0812b1fef1690 100644 --- a/tools/lambda-promtail/lambda-promtail/promtail.go +++ b/tools/lambda-promtail/lambda-promtail/promtail.go @@ -198,6 +198,16 @@ func (c *promtailClient) send(ctx context.Context, buf []byte) (int, error) { req.Header.Set("Authorization", "Bearer "+bearerToken) } + if len(extraHeaders) > 0 { + for key, value := range extraHeaders { + if req.Header.Get(key) != "" { + level.Warn(*c.log).Log("msg", fmt.Sprintf("Not overwriting duplicate header key %s with value: %s! Check EXTRA_HTTP_HEADERS for duplicate keys.", key, value)) + continue + } + req.Header.Set(key, value) + } + } + resp, err := c.http.Do(req.WithContext(ctx)) if err != nil { return -1, err