From 5770ebf7f95b58bad15546a2d1827040d61ec611 Mon Sep 17 00:00:00 2001 From: Chad Bean Date: Fri, 19 Jan 2024 14:25:47 -0500 Subject: [PATCH] Add support for duplicating headers to X-Original- prefix (#163) * Add support for duplicating headers to X-Original- prefix * Set the header on the proxy request instead or origin request * Fix: Duplicate headers unit test --------- Co-authored-by: Erik Burton --- .gitignore | 3 ++ README.md | 12 +++++++ cmd/aws-sigv4-proxy/main.go | 23 +++++++------ handler/proxy_client.go | 35 +++++++++++++------ handler/proxy_client_test.go | 65 ++++++++++++++++++++++++++++++++++-- 5 files changed, 116 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index cc06c539..b98c49a1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ .vscode .idea/ + +# binary +aws-sigv4-proxy diff --git a/README.md b/README.md index 9de0815f..5f21dc1e 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ When running the Proxy, the following flags can be used (none are required) : | `unsigned-payload` | Boolean | Prevent signing of the payload" | `False` | | `port` | String | Port to serve http on | `8080` | | `strip` or `s` | String | Headers to strip from incoming request | None | +| `duplicate-headers` | String | Duplicate headers to an X-Original- prefix name | None | | `role-arn` | String | Amazon Resource Name (ARN) of the role to assume | None | | `name` | String | AWS Service to sign for | None | | `sign-host` | String | Host to sign for | None | @@ -91,6 +92,17 @@ docker run --rm -ti \ aws-sigv4-proxy -v -s Authorization ``` +Running the service and preserving the original Authorization header as X-Original-Authorization (useful because Authorization header will be overwritten.) + +```sh +docker run --rm -ti \ + -v ~/.aws:/root/.aws \ + -p 8080:8080 \ + -e 'AWS_SDK_LOAD_CONFIG=true' \ + -e 'AWS_PROFILE=' \ + aws-sigv4-proxy -v --duplicate-headers Authorization +``` + Running the service with Assume Role to use temporary credentials ```sh diff --git a/cmd/aws-sigv4-proxy/main.go b/cmd/aws-sigv4-proxy/main.go index 62615a87..c4817cde 100644 --- a/cmd/aws-sigv4-proxy/main.go +++ b/cmd/aws-sigv4-proxy/main.go @@ -40,6 +40,7 @@ var ( logSinging = kingpin.Flag("log-signing-process", "Log sigv4 signing process").Bool() port = kingpin.Flag("port", "Port to serve http on").Default(":8080").String() strip = kingpin.Flag("strip", "Headers to strip from incoming request").Short('s').Strings() + duplicateHeaders = kingpin.Flag("duplicate-headers", "Duplicate headers to an X-Original- prefix name").Strings() roleArn = kingpin.Flag("role-arn", "Amazon Resource Name (ARN) of the role to assume").String() signingNameOverride = kingpin.Flag("name", "AWS Service to sign for").String() signingHostOverride = kingpin.Flag("sign-host", "Host to sign for").String() @@ -104,7 +105,7 @@ func main() { } else { credentials = session.Config.Credentials } - + signer := v4.NewSigner(credentials, func(s *v4.Signer) { if shouldLogSigning() { s.Logger = awsLoggerAdapter{} @@ -119,20 +120,22 @@ func main() { } log.WithFields(log.Fields{"StripHeaders": *strip}).Infof("Stripping headers %s", *strip) + log.WithFields(log.Fields{"DuplicateHeaders": *duplicateHeaders}).Infof("Duplicating headers %s", *duplicateHeaders) log.WithFields(log.Fields{"port": *port}).Infof("Listening on %s", *port) log.Fatal( http.ListenAndServe(*port, &handler.Handler{ ProxyClient: &handler.ProxyClient{ - Signer: signer, - Client: client, - StripRequestHeaders: *strip, - SigningNameOverride: *signingNameOverride, - SigningHostOverride: *signingHostOverride, - HostOverride: *hostOverride, - RegionOverride: *regionOverride, - LogFailedRequest: *logFailedResponse, - SchemeOverride: *schemeOverride, + Signer: signer, + Client: client, + StripRequestHeaders: *strip, + DuplicateRequestHeaders: *duplicateHeaders, + SigningNameOverride: *signingNameOverride, + SigningHostOverride: *signingHostOverride, + HostOverride: *hostOverride, + RegionOverride: *regionOverride, + LogFailedRequest: *logFailedResponse, + SchemeOverride: *schemeOverride, }, }), ) diff --git a/handler/proxy_client.go b/handler/proxy_client.go index e9b66e4e..740d5df2 100644 --- a/handler/proxy_client.go +++ b/handler/proxy_client.go @@ -35,15 +35,16 @@ type Client interface { // ProxyClient implements the Client interface type ProxyClient struct { - Signer *v4.Signer - Client Client - StripRequestHeaders []string - SigningNameOverride string - SigningHostOverride string - HostOverride string - RegionOverride string - LogFailedRequest bool - SchemeOverride string + Signer *v4.Signer + Client Client + StripRequestHeaders []string + DuplicateRequestHeaders []string + SigningNameOverride string + SigningHostOverride string + HostOverride string + RegionOverride string + LogFailedRequest bool + SchemeOverride string } func (p *ProxyClient) sign(req *http.Request, service *endpoints.ResolvedEndpoint) error { @@ -144,7 +145,7 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) { return nil, err } - var reqChunked = chunked(req.TransferEncoding); + var reqChunked = chunked(req.TransferEncoding) // Ignore ContentLength if "chunked" transfer-coding is used. if !reqChunked && req.ContentLength >= 0 { @@ -190,6 +191,20 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) { req.Header.Del(header) } + // Duplicate the header value for any headers specified into a new header + // with an "X-Original-" prefix. + for _, header := range p.DuplicateRequestHeaders { + headerValue := req.Header.Get(header) + if headerValue == "" { + log.WithField("DuplicateHeader", string(header)).Debug("Header empty, will not duplicate:") + continue + } + + log.WithField("DuplicateHeader", string(header)).Debug("Duplicate Header to X-Original-* Prefix:") + newHeaderName := fmt.Sprintf("X-Original-%s", header) + proxyReq.Header.Set(newHeaderName, headerValue) + } + // Add origin headers after request is signed (no overwrite) copyHeaderWithoutOverwrite(proxyReq.Header, req.Header) diff --git a/handler/proxy_client_test.go b/handler/proxy_client_test.go index 8d047c93..edaf1e62 100644 --- a/handler/proxy_client_test.go +++ b/handler/proxy_client_test.go @@ -365,6 +365,60 @@ func TestProxyClient_Do(t *testing.T) { }, }, }, + { + name: "should duplicate specified headers with prefix", + request: &http.Request{ + Method: "GET", + URL: &url.URL{}, + Host: "execute-api.us-west-2.amazonaws.com", + Header: http.Header{ + "Authorization": []string{"customValue"}, + "User-Agent": []string{"customAgent"}, + }, + Body: nil, + }, + proxyClient: &ProxyClient{ + Signer: v4.NewSigner(credentials.NewCredentials(&mockProvider{})), + Client: &mockHTTPClient{}, + DuplicateRequestHeaders: []string{"Authorization"}, + }, + want: &want{ + resp: &http.Response{}, + err: nil, + request: &http.Request{ + Host: "execute-api.us-west-2.amazonaws.com", + Header: http.Header{ + "X-Original-Authorization": []string{"customValue"}, + "User-Agent": []string{"customAgent"}, + }, + }, + }, + }, + { + name: "should not duplicate empty headers with prefix", + request: &http.Request{ + Method: "GET", + URL: &url.URL{}, + Host: "execute-api.us-west-2.amazonaws.com", + Body: nil, + }, + proxyClient: &ProxyClient{ + Signer: v4.NewSigner(credentials.NewCredentials(&mockProvider{})), + Client: &mockHTTPClient{}, + DuplicateRequestHeaders: []string{"NonExistentHeader"}, + }, + want: &want{ + resp: &http.Response{}, + err: nil, + request: &http.Request{ + Host: "execute-api.us-west-2.amazonaws.com", + Header: http.Header{ + // Ensure headers are not present + "X-Original-NonExistentHeader": nil, + }, + }, + }, + }, } for _, tt := range tests { @@ -376,15 +430,21 @@ func TestProxyClient_Do(t *testing.T) { assert.Equal(t, tt.want.err, err) proxyRequest := tt.proxyClient.Client.(*mockHTTPClient).Request + assert.True(t, verifyRequest(proxyRequest, tt.want.request)) if proxyRequest == nil { return } + // Ensure specific headers are propagated (or not in certain cases) to the proxy request + for kk, vv := range tt.want.request.Header { + assert.Equal(t, vv, proxyRequest.Header[kk]) + } + // Ensure encoding is propagated to the proxy request. - assert.Equal(t, chunked(tt.request.TransferEncoding), chunked(proxyRequest.TransferEncoding)); + assert.Equal(t, chunked(tt.request.TransferEncoding), chunked(proxyRequest.TransferEncoding)) if chunked(tt.request.TransferEncoding) { - assert.Equal(t, tt.request.TransferEncoding, proxyRequest.TransferEncoding); + assert.Equal(t, tt.request.TransferEncoding, proxyRequest.TransferEncoding) } else { // Ensure content length is propagated to the proxy request. assert.Equal(t, tt.request.ContentLength, proxyRequest.ContentLength) @@ -408,6 +468,7 @@ func TestProxyClient_Do(t *testing.T) { assert.Equal(t, 0, len(ttBody)) } } + }) } }