diff --git a/http/interceptor.go b/http/interceptor.go index 6b14aead3..a7b616ab5 100644 --- a/http/interceptor.go +++ b/http/interceptor.go @@ -108,7 +108,23 @@ func (i *rwInterceptor) Header() http.Header { return i.w.Header() } -var _ http.ResponseWriter = (*rwInterceptor)(nil) +func (i *rwInterceptor) ReadFrom(r io.Reader) (n int64, err error) { + return io.Copy(i, r) +} + +func (i *rwInterceptor) Flush() { + if !i.wroteHeader { + i.WriteHeader(http.StatusOK) + } +} + +type responseWriter interface { + http.ResponseWriter + io.ReaderFrom + http.Flusher +} + +var _ responseWriter = (*rwInterceptor)(nil) // wrap wraps the interceptor into a response writer that also preserves // the http interfaces implemented by the original response writer to avoid @@ -168,110 +184,28 @@ func wrap(w http.ResponseWriter, r *http.Request, tx types.Transaction) ( var ( hijacker, isHijacker = i.w.(http.Hijacker) pusher, isPusher = i.w.(http.Pusher) - flusher, isFlusher = i.w.(http.Flusher) - reader, isReader = i.w.(io.ReaderFrom) ) switch { - case !isHijacker && !isPusher && !isFlusher && !isReader: + case !isHijacker && isPusher: return struct { - http.ResponseWriter - }{i}, responseProcessor - case !isHijacker && !isPusher && !isFlusher && isReader: - return struct { - http.ResponseWriter - io.ReaderFrom - }{i, reader}, responseProcessor - case !isHijacker && !isPusher && isFlusher && !isReader: - return struct { - http.ResponseWriter - http.Flusher - }{i, flusher}, responseProcessor - case !isHijacker && !isPusher && isFlusher && isReader: - return struct { - http.ResponseWriter - http.Flusher - io.ReaderFrom - }{i, flusher, reader}, responseProcessor - case !isHijacker && isPusher && !isFlusher && !isReader: - return struct { - http.ResponseWriter + responseWriter http.Pusher }{i, pusher}, responseProcessor - case !isHijacker && isPusher && !isFlusher && isReader: - return struct { - http.ResponseWriter - http.Pusher - io.ReaderFrom - }{i, pusher, reader}, responseProcessor - case !isHijacker && isPusher && isFlusher && !isReader: - return struct { - http.ResponseWriter - http.Pusher - http.Flusher - }{i, pusher, flusher}, responseProcessor - case !isHijacker && isPusher && isFlusher && isReader: + case isHijacker && !isPusher: return struct { - http.ResponseWriter - http.Pusher - http.Flusher - io.ReaderFrom - }{i, pusher, flusher, reader}, responseProcessor - case isHijacker && !isPusher && !isFlusher && !isReader: - return struct { - http.ResponseWriter + responseWriter http.Hijacker }{i, hijacker}, responseProcessor - case isHijacker && !isPusher && !isFlusher && isReader: + case isHijacker && isPusher: return struct { - http.ResponseWriter - http.Hijacker - io.ReaderFrom - }{i, hijacker, reader}, responseProcessor - case isHijacker && !isPusher && isFlusher && !isReader: - return struct { - http.ResponseWriter - http.Hijacker - http.Flusher - }{i, hijacker, flusher}, responseProcessor - case isHijacker && !isPusher && isFlusher && isReader: - return struct { - http.ResponseWriter - http.Hijacker - http.Flusher - io.ReaderFrom - }{i, hijacker, flusher, reader}, responseProcessor - case isHijacker && isPusher && !isFlusher && !isReader: - return struct { - http.ResponseWriter + responseWriter http.Hijacker http.Pusher }{i, hijacker, pusher}, responseProcessor - case isHijacker && isPusher && !isFlusher && isReader: - return struct { - http.ResponseWriter - http.Hijacker - http.Pusher - io.ReaderFrom - }{i, hijacker, pusher, reader}, responseProcessor - case isHijacker && isPusher && isFlusher && !isReader: - return struct { - http.ResponseWriter - http.Hijacker - http.Pusher - http.Flusher - }{i, hijacker, pusher, flusher}, responseProcessor - case isHijacker && isPusher && isFlusher && isReader: - return struct { - http.ResponseWriter - http.Hijacker - http.Pusher - http.Flusher - io.ReaderFrom - }{i, hijacker, pusher, flusher, reader}, responseProcessor default: return struct { - http.ResponseWriter + responseWriter }{i}, responseProcessor } } diff --git a/http/interceptor_test.go b/http/interceptor_test.go index e8424705b..e4da8e700 100644 --- a/http/interceptor_test.go +++ b/http/interceptor_test.go @@ -8,6 +8,10 @@ package http import ( + "bufio" + "bytes" + "io" + "net" "net/http" "net/http/httptest" "testing" @@ -44,3 +48,281 @@ func TestWriteHeader(t *testing.T) { t.Errorf("unexpected status code, want %d, have %d", want, have) } } + +func TestWrite(t *testing.T) { + waf, err := coraza.NewWAF(coraza.NewWAFConfig()) + if err != nil { + t.Fatal(err) + } + + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + + rw, responseProcessor := wrap(res, req, tx) + _, err = rw.Write([]byte("hello")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + _, err = rw.Write([]byte("world")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if want, have := 200, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } +} + +func TestWriteWithWriteHeader(t *testing.T) { + waf, err := coraza.NewWAF(coraza.NewWAFConfig()) + if err != nil { + t.Fatal(err) + } + + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + + rw, responseProcessor := wrap(res, req, tx) + rw.WriteHeader(204) + // although we called WriteHeader, status code should be applied until + // responseProcessor is called. + if unwanted, have := 204, res.Code; unwanted == have { + t.Errorf("unexpected status code %d", have) + } + + _, err = rw.Write([]byte("hello")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + _, err = rw.Write([]byte("world")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if want, have := 204, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } +} + +func TestFlush(t *testing.T) { + waf, err := coraza.NewWAF(coraza.NewWAFConfig()) + if err != nil { + t.Fatal(err) + } + + t.Run("WriteHeader before Flush", func(t *testing.T) { + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + rw, responseProcessor := wrap(res, req, tx) + rw.WriteHeader(204) + rw.(http.Flusher).Flush() + // although we called WriteHeader, status code should be applied until + // responseProcessor is called. + if unwanted, have := 204, res.Code; unwanted == have { + t.Errorf("unexpected status code %d", have) + } + + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if want, have := 204, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } + }) + + t.Run("Flush before WriteHeader", func(t *testing.T) { + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + rw, responseProcessor := wrap(res, req, tx) + rw.(http.Flusher).Flush() + rw.WriteHeader(204) + + if want, have := 200, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } + + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if want, have := 200, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } + }) +} + +type testReaderFrom struct { + io.Writer +} + +func (x *testReaderFrom) ReadFrom(r io.Reader) (n int64, err error) { + return io.Copy(x, r) +} + +func TestReadFrom(t *testing.T) { + waf, err := coraza.NewWAF(coraza.NewWAFConfig()) + if err != nil { + t.Fatal(err) + } + + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + + type responseWriter interface { + http.ResponseWriter + http.Flusher + } + + resWithReaderFrom := struct { + responseWriter + io.ReaderFrom + }{ + res, + &testReaderFrom{res}, + } + + rw, responseProcessor := wrap(resWithReaderFrom, req, tx) + rw.WriteHeader(204) + // although we called WriteHeader, status code should be applied until + // responseProcessor is called. + if unwanted, have := 204, res.Code; unwanted == have { + t.Errorf("unexpected status code %d", have) + } + + _, err = rw.(io.ReaderFrom).ReadFrom(bytes.NewBuffer([]byte("hello world"))) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if want, have := 204, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } +} + +type testPusher struct{} + +func (x *testPusher) Push(string, *http.PushOptions) error { + return nil +} + +type testHijacker struct{} + +func (x *testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, nil +} + +func TestInterface(t *testing.T) { + waf, err := coraza.NewWAF(coraza.NewWAFConfig()) + if err != nil { + t.Fatal(err) + } + + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + + t.Run("default", func(t *testing.T) { + rw, _ := wrap(struct { + http.ResponseWriter + }{ + res, + }, req, tx) + + _, ok := rw.(http.Pusher) + if ok { + t.Errorf("expected the wrapped ResponseWriter to not implement http.Pusher") + } + + _, ok = rw.(http.Hijacker) + if ok { + t.Errorf("expected the wrapped ResponseWriter to not implement http.Hijacker") + } + }) + + t.Run("http.Pusher", func(t *testing.T) { + rw, _ := wrap(struct { + http.ResponseWriter + http.Pusher + }{ + res, + &testPusher{}, + }, req, tx) + + _, ok := rw.(http.Pusher) + if !ok { + t.Errorf("expected the wrapped ResponseWriter to implement http.Pusher") + } + + _, ok = rw.(http.Hijacker) + if ok { + t.Errorf("expected the wrapped ResponseWriter to not implement http.Hijacker") + } + }) + + t.Run("http.Hijacker", func(t *testing.T) { + rw, _ := wrap(struct { + http.ResponseWriter + http.Hijacker + }{ + res, + &testHijacker{}, + }, req, tx) + + _, ok := rw.(http.Hijacker) + if !ok { + t.Errorf("expected the wrapped ResponseWriter to implement http.Hijacker") + } + + _, ok = rw.(http.Pusher) + if ok { + t.Errorf("expected the wrapped ResponseWriter to not implement http.Pusher") + } + }) + + t.Run("http.Hijacker and http.Pusher", func(t *testing.T) { + rw, _ := wrap(struct { + http.ResponseWriter + http.Hijacker + http.Pusher + }{ + res, + &testHijacker{}, + &testPusher{}, + }, req, tx) + + _, ok := rw.(http.Hijacker) + if !ok { + t.Errorf("expected the wrapped ResponseWriter to implement http.Hijacker") + } + + _, ok = rw.(http.Pusher) + if !ok { + t.Errorf("expected the wrapped ResponseWriter to implement http.Pusher") + } + }) +}