diff --git a/lib/httplib/reverseproxy/rewriter.go b/lib/httplib/reverseproxy/rewriter.go index 507d2eecc0543..2a23dc272308d 100644 --- a/lib/httplib/reverseproxy/rewriter.go +++ b/lib/httplib/reverseproxy/rewriter.go @@ -126,9 +126,11 @@ func maybeSetXRealIP(req *http.Request) { // maybeSetForwarded sets X-Forwarded-* headers if it is not set to the // scheme of the request. func maybeSetForwarded(req *http.Request) { - // We need to delete the value because httputil.ReverseProxy - // appends to the existing value. - req.Header.Del(XForwardedFor) + // Set X-Forwarded-For since net/http/httputil.ReverseProxy won't + // do this when Rewrite is set. + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + req.Header.Set(XForwardedFor, clientIP) + } if req.Header.Get(XForwardedProto) != "" { return diff --git a/lib/httplib/reverseproxy/rewriter_test.go b/lib/httplib/reverseproxy/rewriter_test.go index b1da6c7241213..0656708523c4e 100644 --- a/lib/httplib/reverseproxy/rewriter_test.go +++ b/lib/httplib/reverseproxy/rewriter_test.go @@ -100,6 +100,7 @@ func TestRewriter(t *testing.T) { hostReq: "teleport.dev:3543", remoteAddr: "1.2.3.4:1234", expected: http.Header{ + XForwardedFor: []string{"1.2.3.4"}, XForwardedHost: []string{"teleport.dev:3543"}, XForwardedPort: []string{"3543"}, XForwardedProto: []string{"https"}, @@ -117,6 +118,7 @@ func TestRewriter(t *testing.T) { hostReq: "teleport.dev:3543", remoteAddr: "1.2.3.4:1234", expected: http.Header{ + XForwardedFor: []string{"1.2.3.4"}, XForwardedHost: []string{"teleport.dev:3543"}, XForwardedPort: []string{"3543"}, XForwardedProto: []string{"http"}, @@ -133,6 +135,7 @@ func TestRewriter(t *testing.T) { hostReq: "teleport.dev", remoteAddr: "1.2.3.4:1234", expected: http.Header{ + XForwardedFor: []string{"1.2.3.4"}, XForwardedHost: []string{"teleport.dev"}, XForwardedPort: []string{"80"}, XForwardedProto: []string{"http"}, @@ -141,9 +144,11 @@ func TestRewriter(t *testing.T) { }, }, } + rewriter := NewHeaderRewriter() // set hostname to make sure it's the same in all tests. rewriter.Hostname = hostname + for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) {