From 1e4af2c322987725f0137f9b25cb8499243372b3 Mon Sep 17 00:00:00 2001 From: Marcus Efraimsson Date: Thu, 15 Dec 2022 10:21:35 +0100 Subject: [PATCH] Introduce interface to ease forward/access HTTP headers (#562) A new interface ForwardHTTPHeaders that allows to typeswitch/typecast to that interface instead and simplify the code. From plugin authors perspective we have documentation for Forward OAuth identity for the logged-in user, Forward cookies for the logged-in user and soon for Forward user header and we provides instruction for accessing these forwarded headers by access the request.Headers field. This field is different for different request types: - For QueryDataRequest and CheckHealthRequest there's a map[string]string (they're suppose to hold environment context/metadata - For CallResourceRequest there's a map[string][]string (suppose to hold HTTP headers since it's HTTP proxy over gRPC basically. With these changes, rather than accessing request.Headers directly to access HTTP headers it's suggested that they use request.GetHTTPHeader() instead, e.g. request.GetHTTPHeader(backend.OAuthIdentityTokenHeaderName) or request.GetHTTPHeader(backend.OAuthIdentityIDTokenHeaderName), to access Forwarded OAuth Identity headers. Using request.GetHTTPHeader() or request.GetHTTPHeaders() would also automatically handle canonical/non-canonical form problems. --- backend/data.go | 49 +++++++- backend/data_test.go | 105 ++++++++++++++++++ backend/diagnostics.go | 55 ++++++++- backend/diagnostics_test.go | 105 ++++++++++++++++++ backend/http_headers.go | 91 +++++++++++++++ backend/http_headers_test.go | 210 +++++++++++++++++++++++++++++++++++ backend/resource.go | 88 +++++++++++++-- backend/resource_test.go | 105 ++++++++++++++++++ 8 files changed, 796 insertions(+), 12 deletions(-) create mode 100644 backend/data_test.go create mode 100644 backend/diagnostics_test.go create mode 100644 backend/http_headers.go create mode 100644 backend/http_headers_test.go create mode 100644 backend/resource_test.go diff --git a/backend/data.go b/backend/data.go index d747de17a..2421e6628 100644 --- a/backend/data.go +++ b/backend/data.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "net/http" "time" "github.com/grafana/grafana-plugin-sdk-go/data" @@ -37,9 +38,50 @@ func (fn QueryDataHandlerFunc) QueryData(ctx context.Context, req *QueryDataRequ // QueryDataRequest contains a single request which contains multiple queries. // It is the input type for a QueryData call. type QueryDataRequest struct { + // PluginContext the contextual information for the request. PluginContext PluginContext - Headers map[string]string - Queries []DataQuery + + // Headers the environment/metadata information for the request. + // + // To access forwarded HTTP headers please use + // GetHTTPHeaders or GetHTTPHeader. + Headers map[string]string + + // Queries the data queries for the request. + Queries []DataQuery +} + +// SetHTTPHeader sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +func (req *QueryDataRequest) SetHTTPHeader(key, value string) { + if req.Headers == nil { + req.Headers = map[string]string{} + } + + setHTTPHeaderInStringMap(req.Headers, key, value) +} + +// DeleteHTTPHeader deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (req *QueryDataRequest) DeleteHTTPHeader(key string) { + deleteHTTPHeaderInStringMap(req.Headers, key) +} + +// GetHTTPHeader gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. +func (req *QueryDataRequest) GetHTTPHeader(key string) string { + return req.GetHTTPHeaders().Get(key) +} + +// GetHTTPHeaders returns HTTP headers. +func (req *QueryDataRequest) GetHTTPHeaders() http.Header { + return getHTTPHeadersFromStringMap(req.Headers) } // DataQuery represents a single query as sent from the frontend. @@ -119,6 +161,7 @@ type DataResponse struct { Status Status } +// ErrDataResponse returns an error DataResponse given status and message. func ErrDataResponse(status Status, message string) DataResponse { return DataResponse{ Error: errors.New(message), @@ -149,3 +192,5 @@ type TimeRange struct { func (tr TimeRange) Duration() time.Duration { return tr.To.Sub(tr.From) } + +var _ ForwardHTTPHeaders = (*QueryDataRequest)(nil) diff --git a/backend/data_test.go b/backend/data_test.go new file mode 100644 index 000000000..140b26c19 --- /dev/null +++ b/backend/data_test.go @@ -0,0 +1,105 @@ +package backend + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestQueryDataRequest(t *testing.T) { + req := &QueryDataRequest{} + const customHeaderName = "X-Custom" + + t.Run("Legacy headers", func(t *testing.T) { + req.Headers = map[string]string{ + "Authorization": "a", + "X-ID-Token": "b", + "Cookie": "c", + customHeaderName: "d", + } + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Empty(t, headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Empty(t, req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader canonical form", func(t *testing.T) { + req.SetHTTPHeader(OAuthIdentityTokenHeaderName, "a") + req.SetHTTPHeader(OAuthIdentityIDTokenHeaderName, "b") + req.SetHTTPHeader(CookiesHeaderName, "c") + req.SetHTTPHeader(customHeaderName, "d") + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Equal(t, "d", headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Equal(t, "d", req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader non-canonical form", func(t *testing.T) { + req.SetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName), "a") + req.SetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName), "b") + req.SetHTTPHeader(strings.ToLower(CookiesHeaderName), "c") + req.SetHTTPHeader(strings.ToLower(customHeaderName), "d") + + t.Run("GetHTTPHeaders non-canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", headers.Get(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", headers.Get(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", headers.Get(strings.ToLower(customHeaderName))) + }) + + t.Run("GetHTTPHeader non-canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", req.GetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", req.GetHTTPHeader(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", req.GetHTTPHeader(strings.ToLower(customHeaderName))) + }) + + t.Run("DeleteHTTPHeader non-canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(CookiesHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(customHeaderName)) + require.Empty(t, req.Headers) + }) + }) +} diff --git a/backend/diagnostics.go b/backend/diagnostics.go index 5631aad12..49d48fe20 100644 --- a/backend/diagnostics.go +++ b/backend/diagnostics.go @@ -2,6 +2,7 @@ package backend import ( "context" + "net/http" "strconv" ) @@ -53,14 +54,58 @@ func (hs HealthStatus) String() string { // CheckHealthRequest contains the healthcheck request type CheckHealthRequest struct { + // PluginContext the contextual information for the request. PluginContext PluginContext - Headers map[string]string + + // Headers the environment/metadata information for the request. + // + // To access forwarded HTTP headers please use + // GetHTTPHeaders or GetHTTPHeader. + Headers map[string]string +} + +// SetHTTPHeader sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +func (req *CheckHealthRequest) SetHTTPHeader(key, value string) { + if req.Headers == nil { + req.Headers = map[string]string{} + } + + setHTTPHeaderInStringMap(req.Headers, key, value) +} + +// DeleteHTTPHeader deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (req *CheckHealthRequest) DeleteHTTPHeader(key string) { + deleteHTTPHeaderInStringMap(req.Headers, key) +} + +// GetHTTPHeader gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. +func (req *CheckHealthRequest) GetHTTPHeader(key string) string { + return req.GetHTTPHeaders().Get(key) +} + +// GetHTTPHeaders returns HTTP headers. +func (req *CheckHealthRequest) GetHTTPHeaders() http.Header { + return getHTTPHeadersFromStringMap(req.Headers) } // CheckHealthResult contains the healthcheck response type CheckHealthResult struct { - Status HealthStatus - Message string + // Status the HealthStatus of the healthcheck. + Status HealthStatus + + // Message the message of the healthcheck, if any. + Message string + + // JSONDetails the details of the healthcheck, if any, encoded as JSON bytes. JSONDetails []byte } @@ -82,10 +127,14 @@ func (fn CollectMetricsHandlerFunc) CollectMetrics(ctx context.Context, req *Col // CollectMetricsRequest contains the metrics request type CollectMetricsRequest struct { + // PluginContext the contextual information for the request. PluginContext PluginContext } // CollectMetricsResult collect metrics result. type CollectMetricsResult struct { + // PrometheusMetrics the Prometheus metrics encoded as bytes. PrometheusMetrics []byte } + +var _ ForwardHTTPHeaders = (*CheckHealthRequest)(nil) diff --git a/backend/diagnostics_test.go b/backend/diagnostics_test.go new file mode 100644 index 000000000..56b0057e3 --- /dev/null +++ b/backend/diagnostics_test.go @@ -0,0 +1,105 @@ +package backend + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCheckHealthRequest(t *testing.T) { + req := &CheckHealthRequest{} + const customHeaderName = "X-Custom" + + t.Run("Legacy headers", func(t *testing.T) { + req.Headers = map[string]string{ + "Authorization": "a", + "X-ID-Token": "b", + "Cookie": "c", + customHeaderName: "d", + } + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Empty(t, headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Empty(t, req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader canonical form", func(t *testing.T) { + req.SetHTTPHeader(OAuthIdentityTokenHeaderName, "a") + req.SetHTTPHeader(OAuthIdentityIDTokenHeaderName, "b") + req.SetHTTPHeader(CookiesHeaderName, "c") + req.SetHTTPHeader(customHeaderName, "d") + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Equal(t, "d", headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Equal(t, "d", req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader non-canonical form", func(t *testing.T) { + req.SetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName), "a") + req.SetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName), "b") + req.SetHTTPHeader(strings.ToLower(CookiesHeaderName), "c") + req.SetHTTPHeader(strings.ToLower(customHeaderName), "d") + + t.Run("GetHTTPHeaders non-canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", headers.Get(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", headers.Get(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", headers.Get(strings.ToLower(customHeaderName))) + }) + + t.Run("GetHTTPHeader non-canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", req.GetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", req.GetHTTPHeader(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", req.GetHTTPHeader(strings.ToLower(customHeaderName))) + }) + + t.Run("DeleteHTTPHeader non-canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(CookiesHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(customHeaderName)) + require.Empty(t, req.Headers) + }) + }) +} diff --git a/backend/http_headers.go b/backend/http_headers.go new file mode 100644 index 000000000..689d8ae59 --- /dev/null +++ b/backend/http_headers.go @@ -0,0 +1,91 @@ +package backend + +import ( + "fmt" + "net/http" + "net/textproto" + "strings" +) + +const ( + // OAuthIdentityTokenHeaderName the header name used for forwarding + // OAuth Identity access token. + OAuthIdentityTokenHeaderName = "Authorization" + + // OAuthIdentityIDTokenHeaderName the header name used for forwarding + // OAuth Identity ID token. + OAuthIdentityIDTokenHeaderName = "X-Id-Token" + + // CookiesHeaderName the header name used for forwarding + // cookies. + CookiesHeaderName = "Cookie" + + httpHeaderPrefix = "http_" +) + +// ForwardHTTPHeaders interface marking that forward of HTTP headers is supported. +type ForwardHTTPHeaders interface { + // SetHTTPHeader sets the header entries associated with key to the + // single element value. It replaces any existing values + // associated with key. The key is case insensitive; it is + // canonicalized by textproto.CanonicalMIMEHeaderKey. + SetHTTPHeader(key, value string) + + // DeleteHTTPHeader deletes the values associated with key. + // The key is case insensitive; it is canonicalized by + // CanonicalHeaderKey. + DeleteHTTPHeader(key string) + + // GetHTTPHeader gets the first value associated with the given key. If + // there are no values associated with the key, Get returns "". + // It is case insensitive; textproto.CanonicalMIMEHeaderKey is + // used to canonicalize the provided key. Get assumes that all + // keys are stored in canonical form. + GetHTTPHeader(key string) string + + // GetHTTPHeaders returns HTTP headers. + GetHTTPHeaders() http.Header +} + +func setHTTPHeaderInStringMap(headers map[string]string, key string, value string) { + if headers == nil { + headers = map[string]string{} + } + + headers[fmt.Sprintf("%s%s", httpHeaderPrefix, key)] = value +} + +func getHTTPHeadersFromStringMap(headers map[string]string) http.Header { + httpHeaders := http.Header{} + + for k, v := range headers { + if textproto.CanonicalMIMEHeaderKey(k) == OAuthIdentityTokenHeaderName { + httpHeaders.Set(k, v) + } + + if textproto.CanonicalMIMEHeaderKey(k) == OAuthIdentityIDTokenHeaderName { + httpHeaders.Set(k, v) + } + + if textproto.CanonicalMIMEHeaderKey(k) == CookiesHeaderName { + httpHeaders.Set(k, v) + } + + if strings.HasPrefix(k, httpHeaderPrefix) { + hKey := strings.TrimPrefix(k, httpHeaderPrefix) + httpHeaders.Set(hKey, v) + } + } + + return httpHeaders +} + +func deleteHTTPHeaderInStringMap(headers map[string]string, key string) { + for k := range headers { + if textproto.CanonicalMIMEHeaderKey(k) == textproto.CanonicalMIMEHeaderKey(key) || + textproto.CanonicalMIMEHeaderKey(k) == textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", httpHeaderPrefix, key)) { + delete(headers, k) + break + } + } +} diff --git a/backend/http_headers_test.go b/backend/http_headers_test.go new file mode 100644 index 000000000..7b4070a6c --- /dev/null +++ b/backend/http_headers_test.go @@ -0,0 +1,210 @@ +package backend + +import ( + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/stretchr/testify/require" +) + +func TestSetHTTPHeaderInStringMap(t *testing.T) { + tcs := []struct { + input map[string]string + expected map[string]string + }{ + { + expected: map[string]string{ + "": "", + "a": "", + }, + }, + { + input: map[string]string{ + "authorization": "a", + "x-id-token": "b", + "cookie": "c", + "x-custom": "d", + }, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "a", + "Authorization": "a", + "x-id-token": "b", + "X-Id-Token": "b", + "cookie": "c", + "Cookie": "c", + "x-custom": "d", + "X-Custom": "d", + }, + }, + { + input: map[string]string{ + "Authorization": "a", + "X-ID-Token": "b", + "Cookie": "c", + "X-Custom": "d", + }, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "a", + "Authorization": "a", + "x-id-token": "b", + "X-Id-Token": "b", + "cookie": "c", + "Cookie": "c", + "x-custom": "d", + "X-Custom": "d", + }, + }, + } + + for _, tc := range tcs { + headerMap := map[string]string{} + for k, v := range tc.input { + setHTTPHeaderInStringMap(headerMap, k, v) + } + headers := getHTTPHeadersFromStringMap(headerMap) + spew.Dump(headers) + + for k, v := range tc.expected { + require.Equal(t, v, headers.Get(k)) + } + } +} + +func TestGetHTTPHeadersFromStringMap(t *testing.T) { + tcs := []struct { + input map[string]string + expected map[string]string + }{ + { + expected: map[string]string{ + "": "", + "a": "", + }, + }, + { + input: map[string]string{ + "authorization": "a", + "x-id-token": "b", + "cookie": "c", + httpHeaderPrefix + "x-custom": "d", + }, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "a", + "Authorization": "a", + "x-id-token": "b", + "X-Id-Token": "b", + "cookie": "c", + "Cookie": "c", + "x-custom": "d", + "X-Custom": "d", + }, + }, + { + input: map[string]string{ + "Authorization": "a", + "X-ID-Token": "b", + "Cookie": "c", + httpHeaderPrefix + "X-Custom": "d", + }, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "a", + "Authorization": "a", + "x-id-token": "b", + "X-Id-Token": "b", + "cookie": "c", + "Cookie": "c", + "x-custom": "d", + "X-Custom": "d", + }, + }, + } + + for _, tc := range tcs { + headers := getHTTPHeadersFromStringMap(tc.input) + + for k, v := range tc.expected { + require.Equal(t, v, headers.Get(k)) + } + } +} + +func TestDeleteHTTPHeaderInStringMap(t *testing.T) { + tcs := []struct { + input map[string]string + deleteKeys []string + expected map[string]string + }{ + { + expected: map[string]string{ + "": "", + "a": "", + }, + }, + { + input: map[string]string{ + "authorization": "a", + "x-id-token": "b", + "cookie": "c", + httpHeaderPrefix + "x-custom": "d", + }, + deleteKeys: []string{"authorization", "x-id-token", "cookie", "x-custom"}, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "", + "Authorization": "", + "x-id-token": "", + "X-Id-Token": "", + "cookie": "", + "Cookie": "", + "x-custom": "", + "X-Custom": "", + }, + }, + { + input: map[string]string{ + "Authorization": "a", + "X-ID-Token": "b", + "Cookie": "c", + httpHeaderPrefix + "X-Custom": "d", + }, + deleteKeys: []string{"Authorization", "X-Id-Token", "Cookie", "X-Custom"}, + expected: map[string]string{ + "": "", + "a": "", + "authorization": "", + "Authorization": "", + "x-id-token": "", + "X-Id-Token": "", + "cookie": "", + "Cookie": "", + "x-custom": "", + "X-Custom": "", + }, + }, + } + + for _, tc := range tcs { + headerMap := make(map[string]string, len(tc.input)) + for k, v := range tc.input { + headerMap[k] = v + } + + for _, key := range tc.deleteKeys { + deleteHTTPHeaderInStringMap(headerMap, key) + } + headers := getHTTPHeadersFromStringMap(headerMap) + + for k, v := range tc.expected { + require.Equal(t, v, headers.Get(k)) + } + } +} diff --git a/backend/resource.go b/backend/resource.go index 248f91670..83198dcd5 100644 --- a/backend/resource.go +++ b/backend/resource.go @@ -2,23 +2,95 @@ package backend import ( "context" + "net/http" + "net/textproto" ) // CallResourceRequest represents a request for a resource call. type CallResourceRequest struct { + // PluginContext the contextual information for the request. PluginContext PluginContext - Path string - Method string - URL string - Headers map[string][]string - Body []byte + + // Path the forwarded HTTP path for the request. + Path string + + // Method the forwarded HTTP method for the request. + Method string + + // URL the forwarded HTTP URL for the request. + URL string + + // Headers the forwarded HTTP headers for the request, if any. + // + // Recommended to use GetHTTPHeaders or GetHTTPHeader + // since it automatically handles canonicalization of + // HTTP header keys. + Headers map[string][]string + + // Body the forwarded HTTP body for the request, if any. + Body []byte +} + +// SetHTTPHeader sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +func (req *CallResourceRequest) SetHTTPHeader(key, value string) { + if req.Headers == nil { + req.Headers = map[string][]string{} + } + + req.Headers[key] = []string{value} +} + +// DeleteHTTPHeader deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (req *CallResourceRequest) DeleteHTTPHeader(key string) { + if req.Headers == nil { + return + } + + for k := range req.Headers { + if textproto.CanonicalMIMEHeaderKey(k) == textproto.CanonicalMIMEHeaderKey(key) { + delete(req.Headers, k) + break + } + } +} + +// GetHTTPHeader gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. +func (req *CallResourceRequest) GetHTTPHeader(key string) string { + return req.GetHTTPHeaders().Get(key) +} + +// GetHTTPHeaders returns HTTP headers. +func (req *CallResourceRequest) GetHTTPHeaders() http.Header { + httpHeaders := http.Header{} + + for k, v := range req.Headers { + for _, strVal := range v { + httpHeaders.Add(k, strVal) + } + } + + return httpHeaders } // CallResourceResponse represents a response from a resource call. type CallResourceResponse struct { - Status int + // Status the HTTP response status. + Status int + + // Headers the HTTP response headers. Headers map[string][]string - Body []byte + + // Body the HTTP response body. + Body []byte } // CallResourceResponseSender is used for sending resource call responses. @@ -41,3 +113,5 @@ type CallResourceHandlerFunc func(ctx context.Context, req *CallResourceRequest, func (fn CallResourceHandlerFunc) CallResource(ctx context.Context, req *CallResourceRequest, sender CallResourceResponseSender) error { return fn(ctx, req, sender) } + +var _ ForwardHTTPHeaders = (*CallResourceRequest)(nil) diff --git a/backend/resource_test.go b/backend/resource_test.go new file mode 100644 index 000000000..5dfe2420e --- /dev/null +++ b/backend/resource_test.go @@ -0,0 +1,105 @@ +package backend + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCallResourceRequest(t *testing.T) { + req := &CallResourceRequest{} + const customHeaderName = "X-Custom" + + t.Run("Legacy headers", func(t *testing.T) { + req.Headers = map[string][]string{ + "Authorization": {"a"}, + "X-ID-Token": {"b"}, + "Cookie": {"c"}, + customHeaderName: {"d"}, + } + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Equal(t, "d", headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Equal(t, "d", req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader canonical form", func(t *testing.T) { + req.SetHTTPHeader(OAuthIdentityTokenHeaderName, "a") + req.SetHTTPHeader(OAuthIdentityIDTokenHeaderName, "b") + req.SetHTTPHeader(CookiesHeaderName, "c") + req.SetHTTPHeader(customHeaderName, "d") + + t.Run("GetHTTPHeaders canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", headers.Get(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", headers.Get(CookiesHeaderName)) + require.Equal(t, "d", headers.Get(customHeaderName)) + }) + + t.Run("GetHTTPHeader canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(OAuthIdentityTokenHeaderName)) + require.Equal(t, "b", req.GetHTTPHeader(OAuthIdentityIDTokenHeaderName)) + require.Equal(t, "c", req.GetHTTPHeader(CookiesHeaderName)) + require.Equal(t, "d", req.GetHTTPHeader(customHeaderName)) + }) + + t.Run("DeleteHTTPHeader canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(OAuthIdentityTokenHeaderName) + req.DeleteHTTPHeader(OAuthIdentityIDTokenHeaderName) + req.DeleteHTTPHeader(CookiesHeaderName) + req.DeleteHTTPHeader(customHeaderName) + require.Empty(t, req.Headers) + }) + }) + + t.Run("SetHTTPHeader non-canonical form", func(t *testing.T) { + req.SetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName), "a") + req.SetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName), "b") + req.SetHTTPHeader(strings.ToLower(CookiesHeaderName), "c") + req.SetHTTPHeader(strings.ToLower(customHeaderName), "d") + + t.Run("GetHTTPHeaders non-canonical form", func(t *testing.T) { + headers := req.GetHTTPHeaders() + require.Equal(t, "a", headers.Get(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", headers.Get(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", headers.Get(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", headers.Get(strings.ToLower(customHeaderName))) + }) + + t.Run("GetHTTPHeader non-canonical form", func(t *testing.T) { + require.Equal(t, "a", req.GetHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName))) + require.Equal(t, "b", req.GetHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName))) + require.Equal(t, "c", req.GetHTTPHeader(strings.ToLower(CookiesHeaderName))) + require.Equal(t, "d", req.GetHTTPHeader(strings.ToLower(customHeaderName))) + }) + + t.Run("DeleteHTTPHeader non-canonical form", func(t *testing.T) { + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(OAuthIdentityIDTokenHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(CookiesHeaderName)) + req.DeleteHTTPHeader(strings.ToLower(customHeaderName)) + require.Empty(t, req.Headers) + }) + }) +}