From 5ea7c94104db2284400376ec5b6aaf1ee6895bca Mon Sep 17 00:00:00 2001 From: Romain Menke Date: Mon, 20 Nov 2023 16:53:17 +0100 Subject: [PATCH 1/9] fix http.Flusher and io.ReaderFrom implementation --- http/interceptor.go | 113 +++++++++------------------------------ http/interceptor_test.go | 48 +++++++++++++++++ 2 files changed, 73 insertions(+), 88 deletions(-) diff --git a/http/interceptor.go b/http/interceptor.go index 6b14aead3..016912cee 100644 --- a/http/interceptor.go +++ b/http/interceptor.go @@ -108,7 +108,22 @@ func (i *rwInterceptor) Header() http.Header { return i.w.Header() } -var _ http.ResponseWriter = (*rwInterceptor)(nil) +func (i *rwInterceptor) ReadFrom(r io.Reader) (n int64, err error) { + return io.Copy(i, r) +} + +func (i *rwInterceptor) Flush() { + // coraza middleware always needs to buffer the entire request, response cycle + // we can not flush early +} + +type responseWriter interface { + http.ResponseWriter + io.ReaderFrom + http.Flusher +} + +var _ responseWriter = (*rwInterceptor)(nil) // wrap wraps the interceptor into a response writer that also preserves // the http interfaces implemented by the original response writer to avoid @@ -168,110 +183,32 @@ func wrap(w http.ResponseWriter, r *http.Request, tx types.Transaction) ( var ( hijacker, isHijacker = i.w.(http.Hijacker) pusher, isPusher = i.w.(http.Pusher) - flusher, isFlusher = i.w.(http.Flusher) - reader, isReader = i.w.(io.ReaderFrom) ) switch { - case !isHijacker && !isPusher && !isFlusher && !isReader: + case !isHijacker && !isPusher: return struct { - http.ResponseWriter + responseWriter }{i}, responseProcessor - case !isHijacker && !isPusher && !isFlusher && isReader: - return struct { - http.ResponseWriter - io.ReaderFrom - }{i, reader}, responseProcessor - case !isHijacker && !isPusher && isFlusher && !isReader: - return struct { - http.ResponseWriter - http.Flusher - }{i, flusher}, responseProcessor - case !isHijacker && !isPusher && isFlusher && isReader: + case !isHijacker && isPusher: return struct { - http.ResponseWriter - http.Flusher - io.ReaderFrom - }{i, flusher, reader}, responseProcessor - case !isHijacker && isPusher && !isFlusher && !isReader: - return struct { - http.ResponseWriter + responseWriter http.Pusher }{i, pusher}, responseProcessor - case !isHijacker && isPusher && !isFlusher && isReader: - return struct { - http.ResponseWriter - http.Pusher - io.ReaderFrom - }{i, pusher, reader}, responseProcessor - case !isHijacker && isPusher && isFlusher && !isReader: - return struct { - http.ResponseWriter - http.Pusher - http.Flusher - }{i, pusher, flusher}, responseProcessor - case !isHijacker && isPusher && isFlusher && isReader: - return struct { - http.ResponseWriter - http.Pusher - http.Flusher - io.ReaderFrom - }{i, pusher, flusher, reader}, responseProcessor - case isHijacker && !isPusher && !isFlusher && !isReader: + case isHijacker && !isPusher: return struct { - http.ResponseWriter + responseWriter http.Hijacker }{i, hijacker}, responseProcessor - case isHijacker && !isPusher && !isFlusher && isReader: + case isHijacker && isPusher: return struct { - http.ResponseWriter - http.Hijacker - io.ReaderFrom - }{i, hijacker, reader}, responseProcessor - case isHijacker && !isPusher && isFlusher && !isReader: - return struct { - http.ResponseWriter - http.Hijacker - http.Flusher - }{i, hijacker, flusher}, responseProcessor - case isHijacker && !isPusher && isFlusher && isReader: - return struct { - http.ResponseWriter - http.Hijacker - http.Flusher - io.ReaderFrom - }{i, hijacker, flusher, reader}, responseProcessor - case isHijacker && isPusher && !isFlusher && !isReader: - return struct { - http.ResponseWriter + responseWriter http.Hijacker http.Pusher }{i, hijacker, pusher}, responseProcessor - case isHijacker && isPusher && !isFlusher && isReader: - return struct { - http.ResponseWriter - http.Hijacker - http.Pusher - io.ReaderFrom - }{i, hijacker, pusher, reader}, responseProcessor - case isHijacker && isPusher && isFlusher && !isReader: - return struct { - http.ResponseWriter - http.Hijacker - http.Pusher - http.Flusher - }{i, hijacker, pusher, flusher}, responseProcessor - case isHijacker && isPusher && isFlusher && isReader: - return struct { - http.ResponseWriter - http.Hijacker - http.Pusher - http.Flusher - io.ReaderFrom - }{i, hijacker, pusher, flusher, reader}, responseProcessor default: return struct { - http.ResponseWriter + responseWriter }{i}, responseProcessor } } diff --git a/http/interceptor_test.go b/http/interceptor_test.go index e8424705b..7d00161c2 100644 --- a/http/interceptor_test.go +++ b/http/interceptor_test.go @@ -8,6 +8,8 @@ package http import ( + "bytes" + "io" "net/http" "net/http/httptest" "testing" @@ -44,3 +46,49 @@ func TestWriteHeader(t *testing.T) { t.Errorf("unexpected status code, want %d, have %d", want, have) } } + +func TestFlush(t *testing.T) { + waf, err := coraza.NewWAF(coraza.NewWAFConfig()) + if err != nil { + t.Fatal(err) + } + + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + rw, responseProcessor := wrap(res, req, tx) + rw.WriteHeader(204) + rw.(http.Flusher).Flush() + // although we called WriteHeader, status code should be applied until + // responseProcessor is called. + if unwanted, have := 204, res.Code; unwanted == have { + t.Errorf("unexpected status code %d", have) + } + + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if want, have := 204, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } +} + +func TestReadFrom(t *testing.T) { + waf, err := coraza.NewWAF(coraza.NewWAFConfig()) + if err != nil { + t.Fatal(err) + } + + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + rw, _ := wrap(res, req, tx) + rw.WriteHeader(204) + rw.(io.ReaderFrom).ReadFrom(bytes.NewBuffer([]byte("hello world"))) + + if want, have := 204, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } +} From a2c910cfd74560b9309655a492720b47a3cc9208 Mon Sep 17 00:00:00 2001 From: Romain Menke Date: Mon, 20 Nov 2023 17:46:45 +0100 Subject: [PATCH 2/9] lint --- http/interceptor_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/http/interceptor_test.go b/http/interceptor_test.go index 7d00161c2..d523cb6c6 100644 --- a/http/interceptor_test.go +++ b/http/interceptor_test.go @@ -86,7 +86,10 @@ func TestReadFrom(t *testing.T) { res := httptest.NewRecorder() rw, _ := wrap(res, req, tx) rw.WriteHeader(204) - rw.(io.ReaderFrom).ReadFrom(bytes.NewBuffer([]byte("hello world"))) + _, err = rw.(io.ReaderFrom).ReadFrom(bytes.NewBuffer([]byte("hello world"))) + if err != nil { + t.Errorf("unexpected error: %v", err) + } if want, have := 204, res.Code; want != have { t.Errorf("unexpected status code, want %d, have %d", want, have) From 1a524e441af514c89cb9f1f407d5d273cc2356ff Mon Sep 17 00:00:00 2001 From: Romain Menke Date: Mon, 20 Nov 2023 19:05:48 +0100 Subject: [PATCH 3/9] increase test coverage for interceptor --- http/interceptor_test.go | 152 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 151 insertions(+), 1 deletion(-) diff --git a/http/interceptor_test.go b/http/interceptor_test.go index d523cb6c6..8c8e726b3 100644 --- a/http/interceptor_test.go +++ b/http/interceptor_test.go @@ -47,6 +47,75 @@ func TestWriteHeader(t *testing.T) { } } +func TestWrite(t *testing.T) { + waf, err := coraza.NewWAF(coraza.NewWAFConfig()) + if err != nil { + t.Fatal(err) + } + + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + + rw, responseProcessor := wrap(res, req, tx) + _, err = rw.Write([]byte("hello")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + _, err = rw.Write([]byte("world")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if want, have := 200, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } +} + +func TestWriteWithWriteHeader(t *testing.T) { + waf, err := coraza.NewWAF(coraza.NewWAFConfig()) + if err != nil { + t.Fatal(err) + } + + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + + rw, responseProcessor := wrap(res, req, tx) + rw.WriteHeader(204) + // although we called WriteHeader, status code should be applied until + // responseProcessor is called. + if unwanted, have := 204, res.Code; unwanted == have { + t.Errorf("unexpected status code %d", have) + } + + _, err = rw.Write([]byte("hello")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + _, err = rw.Write([]byte("world")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if want, have := 204, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } +} + func TestFlush(t *testing.T) { waf, err := coraza.NewWAF(coraza.NewWAFConfig()) if err != nil { @@ -75,6 +144,14 @@ func TestFlush(t *testing.T) { } } +type testReaderFrom struct { + io.Writer +} + +func (x *testReaderFrom) ReadFrom(r io.Reader) (n int64, err error) { + return io.Copy(x, r) +} + func TestReadFrom(t *testing.T) { waf, err := coraza.NewWAF(coraza.NewWAFConfig()) if err != nil { @@ -84,13 +161,86 @@ func TestReadFrom(t *testing.T) { tx := waf.NewTransaction() req, _ := http.NewRequest("GET", "", nil) res := httptest.NewRecorder() - rw, _ := wrap(res, req, tx) + + type responseWriter interface { + http.ResponseWriter + http.Flusher + } + + resWithReaderFrom := struct { + responseWriter + io.ReaderFrom + }{ + res, + &testReaderFrom{res}, + } + + rw, responseProcessor := wrap(resWithReaderFrom, req, tx) rw.WriteHeader(204) + // although we called WriteHeader, status code should be applied until + // responseProcessor is called. + if unwanted, have := 204, res.Code; unwanted == have { + t.Errorf("unexpected status code %d", have) + } + _, err = rw.(io.ReaderFrom).ReadFrom(bytes.NewBuffer([]byte("hello world"))) if err != nil { t.Errorf("unexpected error: %v", err) } + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if want, have := 204, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } +} + +type testPusher struct{} + +func (x *testPusher) Push(target string, opts *http.PushOptions) error { + return nil +} + +func TestPusher(t *testing.T) { + waf, err := coraza.NewWAF(coraza.NewWAFConfig()) + if err != nil { + t.Fatal(err) + } + + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + + type responseWriter interface { + http.ResponseWriter + http.Flusher + } + + resWithPush := struct { + responseWriter + http.Pusher + }{ + res, + &testPusher{}, + } + + rw, responseProcessor := wrap(resWithPush, req, tx) + rw.WriteHeader(204) + rw.(http.Pusher).Push("http://example.com", nil) + // although we called WriteHeader, status code should be applied until + // responseProcessor is called. + if unwanted, have := 204, res.Code; unwanted == have { + t.Errorf("unexpected status code %d", have) + } + + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if want, have := 204, res.Code; want != have { t.Errorf("unexpected status code, want %d, have %d", want, have) } From 86f9979a04aea4fab43e7b6ac81774a330a98a9b Mon Sep 17 00:00:00 2001 From: Romain Menke Date: Mon, 20 Nov 2023 19:24:53 +0100 Subject: [PATCH 4/9] lint --- http/interceptor_test.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/http/interceptor_test.go b/http/interceptor_test.go index 8c8e726b3..38d9d5c0d 100644 --- a/http/interceptor_test.go +++ b/http/interceptor_test.go @@ -229,7 +229,11 @@ func TestPusher(t *testing.T) { rw, responseProcessor := wrap(resWithPush, req, tx) rw.WriteHeader(204) - rw.(http.Pusher).Push("http://example.com", nil) + err = rw.(http.Pusher).Push("http://example.com", nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + // although we called WriteHeader, status code should be applied until // responseProcessor is called. if unwanted, have := 204, res.Code; unwanted == have { From 82080873fe639baf017d802d03a01d82a83f1ec4 Mon Sep 17 00:00:00 2001 From: Romain Menke <11521496+romainmenke@users.noreply.github.com> Date: Mon, 20 Nov 2023 20:15:30 +0100 Subject: [PATCH 5/9] Update http/interceptor_test.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Carlos Chávez --- http/interceptor_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/http/interceptor_test.go b/http/interceptor_test.go index 38d9d5c0d..a23c12cb4 100644 --- a/http/interceptor_test.go +++ b/http/interceptor_test.go @@ -200,7 +200,7 @@ func TestReadFrom(t *testing.T) { type testPusher struct{} -func (x *testPusher) Push(target string, opts *http.PushOptions) error { +func (x *testPusher) Push(string, *http.PushOptions) error { return nil } From 38edc61925abd0bb52e345c873dc0feef4974ec7 Mon Sep 17 00:00:00 2001 From: Romain Menke Date: Mon, 20 Nov 2023 22:49:47 +0100 Subject: [PATCH 6/9] increase test coverage --- http/interceptor.go | 4 -- http/interceptor_test.go | 125 ++++++++++++++++++++++++++++----------- 2 files changed, 90 insertions(+), 39 deletions(-) diff --git a/http/interceptor.go b/http/interceptor.go index 016912cee..746622caa 100644 --- a/http/interceptor.go +++ b/http/interceptor.go @@ -186,10 +186,6 @@ func wrap(w http.ResponseWriter, r *http.Request, tx types.Transaction) ( ) switch { - case !isHijacker && !isPusher: - return struct { - responseWriter - }{i}, responseProcessor case !isHijacker && isPusher: return struct { responseWriter diff --git a/http/interceptor_test.go b/http/interceptor_test.go index 38d9d5c0d..46ec5408c 100644 --- a/http/interceptor_test.go +++ b/http/interceptor_test.go @@ -8,8 +8,10 @@ package http import ( + "bufio" "bytes" "io" + "net" "net/http" "net/http/httptest" "testing" @@ -204,7 +206,13 @@ func (x *testPusher) Push(target string, opts *http.PushOptions) error { return nil } -func TestPusher(t *testing.T) { +type testHijacker struct{} + +func (x *testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, nil +} + +func TestInterface(t *testing.T) { waf, err := coraza.NewWAF(coraza.NewWAFConfig()) if err != nil { t.Fatal(err) @@ -214,38 +222,85 @@ func TestPusher(t *testing.T) { req, _ := http.NewRequest("GET", "", nil) res := httptest.NewRecorder() - type responseWriter interface { - http.ResponseWriter - http.Flusher - } - - resWithPush := struct { - responseWriter - http.Pusher - }{ - res, - &testPusher{}, - } - - rw, responseProcessor := wrap(resWithPush, req, tx) - rw.WriteHeader(204) - err = rw.(http.Pusher).Push("http://example.com", nil) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - // although we called WriteHeader, status code should be applied until - // responseProcessor is called. - if unwanted, have := 204, res.Code; unwanted == have { - t.Errorf("unexpected status code %d", have) - } - - err = responseProcessor(tx, req) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - if want, have := 204, res.Code; want != have { - t.Errorf("unexpected status code, want %d, have %d", want, have) - } + t.Run("default", func(t *testing.T) { + rw, _ := wrap(struct { + http.ResponseWriter + }{ + res, + }, req, tx) + + _, ok := rw.(http.Pusher) + if ok { + t.Errorf("expected the wrapped ResponseWriter to not implement http.Pusher") + } + + _, ok = rw.(http.Hijacker) + if ok { + t.Errorf("expected the wrapped ResponseWriter to not implement http.Hijacker") + } + }) + + t.Run("http.Pusher", func(t *testing.T) { + rw, _ := wrap(struct { + http.ResponseWriter + http.Pusher + }{ + res, + &testPusher{}, + }, req, tx) + + _, ok := rw.(http.Pusher) + if !ok { + t.Errorf("expected the wrapped ResponseWriter to implement http.Pusher") + } + + _, ok = rw.(http.Hijacker) + if ok { + t.Errorf("expected the wrapped ResponseWriter to not implement http.Hijacker") + } + }) + + t.Run("http.Hijacker", func(t *testing.T) { + rw, _ := wrap(struct { + http.ResponseWriter + http.Hijacker + }{ + res, + &testHijacker{}, + }, req, tx) + + _, ok := rw.(http.Hijacker) + if !ok { + t.Errorf("expected the wrapped ResponseWriter to implement http.Hijacker") + } + + _, ok = rw.(http.Pusher) + if ok { + t.Errorf("expected the wrapped ResponseWriter to not implement http.Pusher") + } + }) + + t.Run("http.Hijacker and http.Pusher", func(t *testing.T) { + rw, _ := wrap(struct { + http.ResponseWriter + http.Flusher + http.Hijacker + http.Pusher + }{ + res, + res, + &testHijacker{}, + &testPusher{}, + }, req, tx) + + _, ok := rw.(http.Hijacker) + if !ok { + t.Errorf("expected the wrapped ResponseWriter to implement http.Hijacker") + } + + _, ok = rw.(http.Pusher) + if !ok { + t.Errorf("expected the wrapped ResponseWriter to implement http.Pusher") + } + }) } From d9f6704e4306c504f93d68828593116f1b199cda Mon Sep 17 00:00:00 2001 From: Romain Menke Date: Mon, 20 Nov 2023 23:09:45 +0100 Subject: [PATCH 7/9] better implementation for http.Flusher --- http/interceptor.go | 5 ++-- http/interceptor_test.go | 60 ++++++++++++++++++++++++++++------------ 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/http/interceptor.go b/http/interceptor.go index 746622caa..b069119d4 100644 --- a/http/interceptor.go +++ b/http/interceptor.go @@ -113,8 +113,9 @@ func (i *rwInterceptor) ReadFrom(r io.Reader) (n int64, err error) { } func (i *rwInterceptor) Flush() { - // coraza middleware always needs to buffer the entire request, response cycle - // we can not flush early + if !i.wroteHeader { + i.WriteHeader(i.statusCode) + } } type responseWriter interface { diff --git a/http/interceptor_test.go b/http/interceptor_test.go index f1a125ad9..dc4c6a081 100644 --- a/http/interceptor_test.go +++ b/http/interceptor_test.go @@ -124,26 +124,50 @@ func TestFlush(t *testing.T) { t.Fatal(err) } - tx := waf.NewTransaction() - req, _ := http.NewRequest("GET", "", nil) - res := httptest.NewRecorder() - rw, responseProcessor := wrap(res, req, tx) - rw.WriteHeader(204) - rw.(http.Flusher).Flush() - // although we called WriteHeader, status code should be applied until - // responseProcessor is called. - if unwanted, have := 204, res.Code; unwanted == have { - t.Errorf("unexpected status code %d", have) - } + t.Run("WriteHeader before Flush", func(t *testing.T) { + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + rw, responseProcessor := wrap(res, req, tx) + rw.WriteHeader(204) + rw.(http.Flusher).Flush() + // although we called WriteHeader, status code should be applied until + // responseProcessor is called. + if unwanted, have := 204, res.Code; unwanted == have { + t.Errorf("unexpected status code %d", have) + } - err = responseProcessor(tx, req) - if err != nil { - t.Errorf("unexpected error: %v", err) - } + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } - if want, have := 204, res.Code; want != have { - t.Errorf("unexpected status code, want %d, have %d", want, have) - } + if want, have := 204, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } + }) + + t.Run("Flush before WriteHeader", func(t *testing.T) { + tx := waf.NewTransaction() + req, _ := http.NewRequest("GET", "", nil) + res := httptest.NewRecorder() + rw, responseProcessor := wrap(res, req, tx) + rw.(http.Flusher).Flush() + rw.WriteHeader(204) + + if want, have := 200, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } + + err = responseProcessor(tx, req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if want, have := 200, res.Code; want != have { + t.Errorf("unexpected status code, want %d, have %d", want, have) + } + }) } type testReaderFrom struct { From ff07fa44ac843ce37c4a6bea07ff26ababd08e6c Mon Sep 17 00:00:00 2001 From: Romain Menke Date: Mon, 20 Nov 2023 23:13:50 +0100 Subject: [PATCH 8/9] cleanup --- http/interceptor.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/http/interceptor.go b/http/interceptor.go index b069119d4..a7b616ab5 100644 --- a/http/interceptor.go +++ b/http/interceptor.go @@ -114,7 +114,7 @@ func (i *rwInterceptor) ReadFrom(r io.Reader) (n int64, err error) { func (i *rwInterceptor) Flush() { if !i.wroteHeader { - i.WriteHeader(i.statusCode) + i.WriteHeader(http.StatusOK) } } From fd8a1736f019e0d37d40730e5bcec14757f148a8 Mon Sep 17 00:00:00 2001 From: Romain Menke Date: Mon, 20 Nov 2023 23:19:33 +0100 Subject: [PATCH 9/9] cleanup --- http/interceptor_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/http/interceptor_test.go b/http/interceptor_test.go index dc4c6a081..e4da8e700 100644 --- a/http/interceptor_test.go +++ b/http/interceptor_test.go @@ -307,11 +307,9 @@ func TestInterface(t *testing.T) { t.Run("http.Hijacker and http.Pusher", func(t *testing.T) { rw, _ := wrap(struct { http.ResponseWriter - http.Flusher http.Hijacker http.Pusher }{ - res, res, &testHijacker{}, &testPusher{},