Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v16] use httputil Rewrite instead of Director #43367

Merged
merged 3 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions lib/httplib/reverseproxy/reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ func New(opts ...Option) (*Forwarder, error) {
opt(fwd)
}

// Director is called by the ReverseProxy to modify the request.
fwd.Director = func(request *http.Request) {
modifyRequest(request)
// Rewrite is called by the ReverseProxy to modify the request.
fwd.Rewrite = func(request *httputil.ProxyRequest) {
modifyRequest(request.Out)
if fwd.headerRewriter != nil {
fwd.headerRewriter.Rewrite(request)
}
if !fwd.passHostHeader {
request.Host = request.URL.Host
request.Out.Host = request.Out.URL.Host
}
}

Expand Down
39 changes: 28 additions & 11 deletions lib/httplib/reverseproxy/rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ package reverseproxy
import (
"net"
"net/http"
"net/http/httputil"
"os"
"strings"
)

// Rewriter is an interface for rewriting http requests.
type Rewriter interface {
Rewrite(*http.Request)
Rewrite(req *httputil.ProxyRequest)
}

// NewHeaderRewriter creates a new HeaderRewriter.
Expand All @@ -46,29 +47,45 @@ type HeaderRewriter struct {
}

// Rewrite request headers.
func (rw *HeaderRewriter) Rewrite(req *http.Request) {
if !rw.TrustForwardHeader {
func (rw *HeaderRewriter) Rewrite(req *httputil.ProxyRequest) {
if rw.TrustForwardHeader {
// net/http/httputil.ReverseProxy will strip some forwarding
// headers from the outbound request when Rewrite is set, which
// is what we use. If we trust the forwarding headers ensure they
// are added back to the outbound request.
for _, h := range XHeaders {
req.Header.Del(h)
val := req.In.Header.Get(h)
if val == "" {
continue
}
req.Out.Header.Set(h, val)
}
} else {
// if we don't trust the forwarding headers, ensure all are removed
// as net/http/httputil.ReverseProxy won't remove all the forwarding
// headers we care about.
for _, h := range XHeaders {
req.Out.Header.Del(h)
}
}
outReq := req.Out

// Set X-Real-IP header if it is not set to the IP address of the client making the request.
maybeSetXRealIP(req)
maybeSetXRealIP(outReq)

// Set X-Forwarded-* headers if it is not set to the scheme of the request.
maybeSetForwarded(req)
maybeSetForwarded(outReq)

if xfPort := req.Header.Get(XForwardedPort); xfPort == "" {
req.Header.Set(XForwardedPort, forwardedPort(req))
if xfPort := outReq.Header.Get(XForwardedPort); xfPort == "" {
outReq.Header.Set(XForwardedPort, forwardedPort(outReq))
}

if xfHost := req.Header.Get(XForwardedHost); xfHost == "" && req.Host != "" {
req.Header.Set(XForwardedHost, req.Host)
if xfHost := outReq.Header.Get(XForwardedHost); xfHost == "" && outReq.Host != "" {
outReq.Header.Set(XForwardedHost, outReq.Host)
}

if rw.Hostname != "" {
req.Header.Set(XForwardedServer, rw.Hostname)
outReq.Header.Set(XForwardedServer, rw.Hostname)
}
}

Expand Down
17 changes: 15 additions & 2 deletions lib/httplib/reverseproxy/rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package reverseproxy
import (
"crypto/tls"
"net/http"
"net/http/httputil"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -156,8 +157,20 @@ func TestRewriter(t *testing.T) {
if test.tlsReq {
req.TLS = &tls.ConnectionState{}
}
rewriter.Rewrite(req)
require.Equal(t, test.expected, req.Header)

// replicate net/http/httputil.ReverseProxy stripping
// forwarding headers from the outbound request
outReq := req.Clone(req.Context())
outReq.Header.Del("Forwarded")
outReq.Header.Del(XForwardedFor)
outReq.Header.Del(XForwardedHost)
outReq.Header.Del(XForwardedProto)

rewriter.Rewrite(&httputil.ProxyRequest{
In: req,
Out: outReq,
})
require.Equal(t, test.expected, outReq.Header)
})
}
}
10 changes: 5 additions & 5 deletions lib/srv/app/common/header_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
package common

import (
"net/http"
"net/http/httputil"

"github.com/gravitational/teleport/lib/httplib/reverseproxy"
)
Expand All @@ -44,14 +44,14 @@ func NewHeaderRewriter(delegates ...reverseproxy.Rewriter) *HeaderRewriter {

// Rewrite will delegate to the supplied delegates' rewrite functions and then inject
// its own headers.
func (hr *HeaderRewriter) Rewrite(req *http.Request) {
func (hr *HeaderRewriter) Rewrite(req *httputil.ProxyRequest) {
for _, delegate := range hr.delegates {
delegate.Rewrite(req)
}

if req.TLS != nil {
req.Header.Set(XForwardedSSL, sslOn)
if req.Out.TLS != nil {
req.Out.Header.Set(XForwardedSSL, sslOn)
} else {
req.Header.Set(XForwardedSSL, sslOff)
req.Out.Header.Set(XForwardedSSL, sslOff)
}
}
20 changes: 16 additions & 4 deletions lib/srv/app/common/header_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"crypto/tls"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"testing"

Expand Down Expand Up @@ -51,8 +52,8 @@ func newTestDelegate(header, value string) *testDelegate {
}
}

func (t *testDelegate) Rewrite(req *http.Request) {
req.Header.Set(t.header, t.value)
func (t *testDelegate) Rewrite(req *httputil.ProxyRequest) {
req.Out.Header.Set(t.header, t.value)
}

func TestHeaderRewriter(t *testing.T) {
Expand Down Expand Up @@ -128,10 +129,21 @@ func TestHeaderRewriter(t *testing.T) {
delegates = append(delegates, test.extraDelegates...)
hr := NewHeaderRewriter(delegates...)

hr.Rewrite(test.req)
// replicate net/http/httputil.ReverseProxy stripping
// forwarding headers from the outbound request
outReq := test.req.Clone(test.req.Context())
outReq.Header.Del("Forwarded")
outReq.Header.Del(reverseproxy.XForwardedFor)
outReq.Header.Del(reverseproxy.XForwardedHost)
outReq.Header.Del(reverseproxy.XForwardedProto)

hr.Rewrite(&httputil.ProxyRequest{
In: test.req,
Out: outReq,
})

for header, value := range test.expectedHeaders {
assert.Equal(t, test.req.Header.Get(header), value[0])
assert.Equal(t, outReq.Header.Get(header), value[0])
}
})
}
Expand Down
Loading