diff --git a/main.go b/main.go index 361ee43..7e11d11 100644 --- a/main.go +++ b/main.go @@ -1,11 +1,56 @@ package traefik_response_header_forward_plugin import ( + "bufio" + "bytes" "context" "fmt" + "io" + "net" "net/http" ) +var ( + _ interface { + http.ResponseWriter + http.Hijacker + } = &wrappedResponseWriter{} +) + +type wrappedResponseWriter struct { + rw http.ResponseWriter + buf *bytes.Buffer + code int +} + +func (w *wrappedResponseWriter) Header() http.Header { + return w.rw.Header() +} + +func (w *wrappedResponseWriter) Write(b []byte) (int, error) { + return w.buf.Write(b) +} + +func (w *wrappedResponseWriter) WriteHeader(code int) { + w.code = code +} + +func (w *wrappedResponseWriter) Flush() { + w.rw.WriteHeader(w.code) + io.Copy(w.rw, w.buf) +} + +func (w *wrappedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := w.rw.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("%T is not an http.Hijacker", w.rw) + } + + return hijacker.Hijack() +} + +// ======================================== + type RequestHeader struct { Name string `json:"name,omitempty"` } @@ -21,9 +66,9 @@ func CreateConfig() *Config { } type ResponseHeaderForward struct { - next http.Handler - name string - requestHeaders []RequestHeader + next http.Handler + name string + config *Config } func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) { @@ -38,21 +83,28 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h } } return &ResponseHeaderForward{ - next: next, - name: name, - requestHeaders: config.RequestHeaders, + next: next, + name: name, + config: config, }, nil } func (a *ResponseHeaderForward) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - a.next.ServeHTTP(rw, req) + resp := &wrappedResponseWriter{ + rw: rw, + buf: &bytes.Buffer{}, + } + + defer resp.Flush() - // for _, requestHeader := range a.requestHeaders { - // headerValue := req.Header.Get(requestHeader.Name) - // if headerValue == "" { - // continue - // } + a.next.ServeHTTP(resp, req) - // rw.Header().Set(requestHeader.Name, headerValue) - // } + for _, requestHeader := range a.config.RequestHeaders { + headerValue := req.Header.Get(requestHeader.Name) + if headerValue == "" { + continue + } + + resp.Header().Set(requestHeader.Name, headerValue) + } }