Skip to content

Commit

Permalink
fix: headers leaked during interruptions at phase 3/4 (#1062)
Browse files Browse the repository at this point in the history
fix: headers returned during interruptions at phase 3/4
  • Loading branch information
M4tteoP authored May 9, 2024
1 parent cd681b2 commit c1d1ccb
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 56 deletions.
15 changes: 14 additions & 1 deletion http/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func (i *rwInterceptor) WriteHeader(statusCode int) {

i.statusCode = statusCode
if it := i.tx.ProcessResponseHeaders(statusCode, i.proto); it != nil {
i.cleanHeaders()
i.Header().Set("Content-Length", "0")
i.statusCode = obtainStatusCodeFromInterruptionOrDefault(it, i.statusCode)
i.flushWriteHeader()
Expand All @@ -65,6 +66,13 @@ func (i *rwInterceptor) flushWriteHeader() {
}
}

// cleanHeaders removes all headers from the response
func (i *rwInterceptor) cleanHeaders() {
for k := range i.w.Header() {
i.w.Header().Del(k)
}
}

// Write buffers the response body until the request body limit is reach or an
// interruption is triggered, this buffer is later used to analyse the body in
// the response processor.
Expand All @@ -88,7 +96,10 @@ func (i *rwInterceptor) Write(b []byte) (int, error) {
// to it, otherwise we just send it to the response writer.
it, n, err := i.tx.WriteResponseBody(b)
if it != nil {
i.overrideWriteHeader(it.Status)
// if there is an interruption we must clean the headers and override the status code
i.cleanHeaders()
i.Header().Set("Content-Length", "0")
i.overrideWriteHeader(obtainStatusCodeFromInterruptionOrDefault(it, i.statusCode))
// We only flush the status code after an interruption.
i.flushWriteHeader()
return 0, nil
Expand Down Expand Up @@ -153,6 +164,8 @@ func wrap(w http.ResponseWriter, r *http.Request, tx types.Transaction) (
i.flushWriteHeader()
return err
} else if it != nil {
// if there is an interruption we must clean the headers and override the status code
i.cleanHeaders()
i.Header().Set("Content-Length", "0")
i.overrideWriteHeader(obtainStatusCodeFromInterruptionOrDefault(it, i.statusCode))
i.flushWriteHeader()
Expand Down
1 change: 0 additions & 1 deletion http/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,5 @@ func obtainStatusCodeFromInterruptionOrDefault(it *types.Interruption, defaultSt

return statusCode
}

return defaultStatusCode
}
145 changes: 91 additions & 54 deletions http/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,42 +238,53 @@ type httpTest struct {
respBody string
expectedProto string
expectedStatus int
expectedRespHeadersKeys []string
expectedRespBody string
}

var expectedNoBlockingHeaders = []string{"Content-Type", "Content-Length", "Coraza-Middleware", "Date"}

// When an interruption occour, we are expecting that no response headers are sent back to the client.
var expectedBlockingHeaders = []string{"Content-Length", "Date"}

func TestHttpServer(t *testing.T) {
tests := map[string]httpTest{
"no blocking": {
reqURI: "/hello",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
reqURI: "/hello",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespHeadersKeys: expectedNoBlockingHeaders,
},
"no blocking HTTP/2": {
http2: true,
reqURI: "/hello",
expectedProto: "HTTP/2.0",
expectedStatus: 201,
http2: true,
reqURI: "/hello",
expectedProto: "HTTP/2.0",
expectedStatus: 201,
expectedRespHeadersKeys: expectedNoBlockingHeaders,
},
"args blocking": {
reqURI: "/hello?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
reqURI: "/hello?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
expectedRespHeadersKeys: expectedBlockingHeaders,
},
"request body blocking": {
reqURI: "/hello",
reqBody: "eval('cat /etc/passwd')",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
reqURI: "/hello",
reqBody: "eval('cat /etc/passwd')",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
expectedRespHeadersKeys: expectedBlockingHeaders,
},
"request body larger than limit (process partial)": {
reqURI: "/hello",
reqBody: "eval('cat /etc/passwd')",
echoReqBody: true,
// Coraza only sees eva, not eval
reqBodyLimit: 3,
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespBody: "eval('cat /etc/passwd')",
reqBodyLimit: 3,
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespHeadersKeys: expectedNoBlockingHeaders,
expectedRespBody: "eval('cat /etc/passwd')",
},
"request body larger than limit (reject)": {
reqURI: "/hello",
Expand All @@ -283,37 +294,43 @@ func TestHttpServer(t *testing.T) {
shouldRejectOnBodyLimit: true,
expectedProto: "HTTP/1.1",
expectedStatus: 413,
expectedRespHeadersKeys: expectedBlockingHeaders,
expectedRespBody: "",
},
"response headers blocking": {
reqURI: "/hello",
respHeaders: map[string]string{"foo": "bar"},
expectedProto: "HTTP/1.1",
expectedStatus: 401,
reqURI: "/hello",
respHeaders: map[string]string{"foo": "bar"},
expectedProto: "HTTP/1.1",
expectedStatus: 401,
expectedRespHeadersKeys: expectedBlockingHeaders,
},
"response body not blocking": {
reqURI: "/hello",
respBody: "true negative response body",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespBody: "true negative response body",
reqURI: "/hello",
respBody: "true negative response body",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespHeadersKeys: expectedNoBlockingHeaders,
expectedRespBody: "true negative response body",
},
"response body blocking": {
reqURI: "/hello",
respBody: "password=xxxx",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
expectedRespBody: "", // blocking at response body phase means returning it empty
reqURI: "/hello",
respBody: "password=xxxx",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
expectedRespBody: "", // blocking at response body phase means returning it empty
expectedRespHeadersKeys: expectedBlockingHeaders,
},
"allow": {
reqURI: "/allow_me",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
reqURI: "/allow_me",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
expectedRespHeadersKeys: expectedNoBlockingHeaders,
},
"deny passes over allow due to ordering": {
reqURI: "/allow_me?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
reqURI: "/allow_me?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 403,
expectedRespHeadersKeys: expectedBlockingHeaders,
},
}

Expand Down Expand Up @@ -357,26 +374,29 @@ func TestHttpServer(t *testing.T) {
func TestHttpServerWithRuleEngineOff(t *testing.T) {
tests := map[string]httpTest{
"no blocking true negative": {
reqURI: "/hello",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Hello!",
expectedRespBody: "Hello!",
reqURI: "/hello",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Hello!",
expectedRespHeadersKeys: expectedNoBlockingHeaders,
expectedRespBody: "Hello!",
},
"no blocking true positive header phase": {
reqURI: "/hello?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Downstream works!",
expectedRespBody: "Downstream works!",
reqURI: "/hello?id=0",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Downstream works!",
expectedRespHeadersKeys: expectedNoBlockingHeaders,
expectedRespBody: "Downstream works!",
},
"no blocking true positive body phase": {
reqURI: "/hello",
reqBody: "eval('cat /etc/passwd')",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Waf is Off!",
expectedRespBody: "Waf is Off!",
reqURI: "/hello",
reqBody: "eval('cat /etc/passwd')",
expectedProto: "HTTP/1.1",
expectedStatus: 201,
respBody: "Waf is Off!",
expectedRespHeadersKeys: expectedNoBlockingHeaders,
expectedRespBody: "Waf is Off!",
},
}
logger := debuglog.Default().
Expand Down Expand Up @@ -458,6 +478,10 @@ func runAgainstWAF(t *testing.T, tCase httpTest, waf coraza.WAF) {
t.Errorf("unexpected status code, want: %d, have: %d", want, have)
}

if !keysExistInMap(t, tCase.expectedRespHeadersKeys, res.Header) {
t.Errorf("unexpected response headers, expected keys: %v, headers: %v", tCase.expectedRespHeadersKeys, res.Header)
}

resBody, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("unexpected error when reading the response body: %v", err)
Expand All @@ -480,6 +504,19 @@ func runAgainstWAF(t *testing.T, tCase httpTest, waf coraza.WAF) {
}
}

func keysExistInMap(t *testing.T, keys []string, m map[string][]string) bool {
t.Helper()
if len(keys) != len(m) {
return false
}
for _, key := range keys {
if _, ok := m[key]; !ok {
return false
}
}
return true
}

func TestObtainStatusCodeFromInterruptionOrDefault(t *testing.T) {
tCases := map[string]struct {
interruptionCode int
Expand Down

0 comments on commit c1d1ccb

Please sign in to comment.