Skip to content

Commit

Permalink
Add support for duplicating headers to X-Original- prefix (#163)
Browse files Browse the repository at this point in the history
* Add support for duplicating headers to X-Original- prefix

* Set the header on the proxy request instead or origin request

* Fix: Duplicate headers unit test

---------

Co-authored-by: Erik Burton <[email protected]>
  • Loading branch information
chadbean and erikburt authored Jan 19, 2024
1 parent fd35f30 commit 5770ebf
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 22 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
.vscode
.idea/

# binary
aws-sigv4-proxy
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ When running the Proxy, the following flags can be used (none are required) :
| `unsigned-payload` | Boolean | Prevent signing of the payload" | `False` |
| `port` | String | Port to serve http on | `8080` |
| `strip` or `s` | String | Headers to strip from incoming request | None |
| `duplicate-headers` | String | Duplicate headers to an X-Original- prefix name | None |
| `role-arn` | String | Amazon Resource Name (ARN) of the role to assume | None |
| `name` | String | AWS Service to sign for | None |
| `sign-host` | String | Host to sign for | None |
Expand Down Expand Up @@ -91,6 +92,17 @@ docker run --rm -ti \
aws-sigv4-proxy -v -s Authorization
```

Running the service and preserving the original Authorization header as X-Original-Authorization (useful because Authorization header will be overwritten.)

```sh
docker run --rm -ti \
-v ~/.aws:/root/.aws \
-p 8080:8080 \
-e 'AWS_SDK_LOAD_CONFIG=true' \
-e 'AWS_PROFILE=<SOME PROFILE>' \
aws-sigv4-proxy -v --duplicate-headers Authorization
```

Running the service with Assume Role to use temporary credentials

```sh
Expand Down
23 changes: 13 additions & 10 deletions cmd/aws-sigv4-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ var (
logSinging = kingpin.Flag("log-signing-process", "Log sigv4 signing process").Bool()
port = kingpin.Flag("port", "Port to serve http on").Default(":8080").String()
strip = kingpin.Flag("strip", "Headers to strip from incoming request").Short('s').Strings()
duplicateHeaders = kingpin.Flag("duplicate-headers", "Duplicate headers to an X-Original- prefix name").Strings()
roleArn = kingpin.Flag("role-arn", "Amazon Resource Name (ARN) of the role to assume").String()
signingNameOverride = kingpin.Flag("name", "AWS Service to sign for").String()
signingHostOverride = kingpin.Flag("sign-host", "Host to sign for").String()
Expand Down Expand Up @@ -104,7 +105,7 @@ func main() {
} else {
credentials = session.Config.Credentials
}

signer := v4.NewSigner(credentials, func(s *v4.Signer) {
if shouldLogSigning() {
s.Logger = awsLoggerAdapter{}
Expand All @@ -119,20 +120,22 @@ func main() {
}

log.WithFields(log.Fields{"StripHeaders": *strip}).Infof("Stripping headers %s", *strip)
log.WithFields(log.Fields{"DuplicateHeaders": *duplicateHeaders}).Infof("Duplicating headers %s", *duplicateHeaders)
log.WithFields(log.Fields{"port": *port}).Infof("Listening on %s", *port)

log.Fatal(
http.ListenAndServe(*port, &handler.Handler{
ProxyClient: &handler.ProxyClient{
Signer: signer,
Client: client,
StripRequestHeaders: *strip,
SigningNameOverride: *signingNameOverride,
SigningHostOverride: *signingHostOverride,
HostOverride: *hostOverride,
RegionOverride: *regionOverride,
LogFailedRequest: *logFailedResponse,
SchemeOverride: *schemeOverride,
Signer: signer,
Client: client,
StripRequestHeaders: *strip,
DuplicateRequestHeaders: *duplicateHeaders,
SigningNameOverride: *signingNameOverride,
SigningHostOverride: *signingHostOverride,
HostOverride: *hostOverride,
RegionOverride: *regionOverride,
LogFailedRequest: *logFailedResponse,
SchemeOverride: *schemeOverride,
},
}),
)
Expand Down
35 changes: 25 additions & 10 deletions handler/proxy_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ type Client interface {

// ProxyClient implements the Client interface
type ProxyClient struct {
Signer *v4.Signer
Client Client
StripRequestHeaders []string
SigningNameOverride string
SigningHostOverride string
HostOverride string
RegionOverride string
LogFailedRequest bool
SchemeOverride string
Signer *v4.Signer
Client Client
StripRequestHeaders []string
DuplicateRequestHeaders []string
SigningNameOverride string
SigningHostOverride string
HostOverride string
RegionOverride string
LogFailedRequest bool
SchemeOverride string
}

func (p *ProxyClient) sign(req *http.Request, service *endpoints.ResolvedEndpoint) error {
Expand Down Expand Up @@ -144,7 +145,7 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) {
return nil, err
}

var reqChunked = chunked(req.TransferEncoding);
var reqChunked = chunked(req.TransferEncoding)

// Ignore ContentLength if "chunked" transfer-coding is used.
if !reqChunked && req.ContentLength >= 0 {
Expand Down Expand Up @@ -190,6 +191,20 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) {
req.Header.Del(header)
}

// Duplicate the header value for any headers specified into a new header
// with an "X-Original-" prefix.
for _, header := range p.DuplicateRequestHeaders {
headerValue := req.Header.Get(header)
if headerValue == "" {
log.WithField("DuplicateHeader", string(header)).Debug("Header empty, will not duplicate:")
continue
}

log.WithField("DuplicateHeader", string(header)).Debug("Duplicate Header to X-Original-* Prefix:")
newHeaderName := fmt.Sprintf("X-Original-%s", header)
proxyReq.Header.Set(newHeaderName, headerValue)
}

// Add origin headers after request is signed (no overwrite)
copyHeaderWithoutOverwrite(proxyReq.Header, req.Header)

Expand Down
65 changes: 63 additions & 2 deletions handler/proxy_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,60 @@ func TestProxyClient_Do(t *testing.T) {
},
},
},
{
name: "should duplicate specified headers with prefix",
request: &http.Request{
Method: "GET",
URL: &url.URL{},
Host: "execute-api.us-west-2.amazonaws.com",
Header: http.Header{
"Authorization": []string{"customValue"},
"User-Agent": []string{"customAgent"},
},
Body: nil,
},
proxyClient: &ProxyClient{
Signer: v4.NewSigner(credentials.NewCredentials(&mockProvider{})),
Client: &mockHTTPClient{},
DuplicateRequestHeaders: []string{"Authorization"},
},
want: &want{
resp: &http.Response{},
err: nil,
request: &http.Request{
Host: "execute-api.us-west-2.amazonaws.com",
Header: http.Header{
"X-Original-Authorization": []string{"customValue"},
"User-Agent": []string{"customAgent"},
},
},
},
},
{
name: "should not duplicate empty headers with prefix",
request: &http.Request{
Method: "GET",
URL: &url.URL{},
Host: "execute-api.us-west-2.amazonaws.com",
Body: nil,
},
proxyClient: &ProxyClient{
Signer: v4.NewSigner(credentials.NewCredentials(&mockProvider{})),
Client: &mockHTTPClient{},
DuplicateRequestHeaders: []string{"NonExistentHeader"},
},
want: &want{
resp: &http.Response{},
err: nil,
request: &http.Request{
Host: "execute-api.us-west-2.amazonaws.com",
Header: http.Header{
// Ensure headers are not present
"X-Original-NonExistentHeader": nil,
},
},
},
},
}

for _, tt := range tests {
Expand All @@ -376,15 +430,21 @@ func TestProxyClient_Do(t *testing.T) {
assert.Equal(t, tt.want.err, err)

proxyRequest := tt.proxyClient.Client.(*mockHTTPClient).Request

assert.True(t, verifyRequest(proxyRequest, tt.want.request))
if proxyRequest == nil {
return
}

// Ensure specific headers are propagated (or not in certain cases) to the proxy request
for kk, vv := range tt.want.request.Header {
assert.Equal(t, vv, proxyRequest.Header[kk])
}

// Ensure encoding is propagated to the proxy request.
assert.Equal(t, chunked(tt.request.TransferEncoding), chunked(proxyRequest.TransferEncoding));
assert.Equal(t, chunked(tt.request.TransferEncoding), chunked(proxyRequest.TransferEncoding))
if chunked(tt.request.TransferEncoding) {
assert.Equal(t, tt.request.TransferEncoding, proxyRequest.TransferEncoding);
assert.Equal(t, tt.request.TransferEncoding, proxyRequest.TransferEncoding)
} else {
// Ensure content length is propagated to the proxy request.
assert.Equal(t, tt.request.ContentLength, proxyRequest.ContentLength)
Expand All @@ -408,6 +468,7 @@ func TestProxyClient_Do(t *testing.T) {
assert.Equal(t, 0, len(ttBody))
}
}

})
}
}
Expand Down

0 comments on commit 5770ebf

Please sign in to comment.