From 5a56bbd76ed63447bd8a86d9c4f99112dd938a09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Ch=C3=A1vez?= Date: Tue, 21 Nov 2023 06:53:53 +0100 Subject: [PATCH] chore: deletes content temporary file on close. Fixes #922. --- internal/collections/map.go | 12 + internal/corazawaf/transaction.go | 26 ++- internal/corazawaf/transaction_test.go | 292 ++++++++++++++++--------- 3 files changed, 223 insertions(+), 107 deletions(-) diff --git a/internal/collections/map.go b/internal/collections/map.go index eca4dcd38..be17b6a2d 100644 --- a/internal/collections/map.go +++ b/internal/collections/map.go @@ -40,6 +40,18 @@ func (c *Map) Get(key string) []string { return values } +func (c *Map) Keys() []string { + if len(c.data) == 0 { + return nil + } + + var keys = make(([]string), 0, len(c.data)) + for k := range c.data { + keys = append(keys, k) + } + return keys +} + func (c *Map) FindRegex(key *regexp.Regexp) []types.MatchData { var result []types.MatchData for k, data := range c.data { diff --git a/internal/corazawaf/transaction.go b/internal/corazawaf/transaction.go index 3f0be3842..790275581 100644 --- a/internal/corazawaf/transaction.go +++ b/internal/corazawaf/transaction.go @@ -11,6 +11,7 @@ import ( "math" "mime" "net/url" + "os" "path/filepath" "strconv" "strings" @@ -24,6 +25,7 @@ import ( "github.com/corazawaf/coraza/v3/internal/collections" "github.com/corazawaf/coraza/v3/internal/corazarules" "github.com/corazawaf/coraza/v3/internal/corazatypes" + "github.com/corazawaf/coraza/v3/internal/environment" stringsutil "github.com/corazawaf/coraza/v3/internal/strings" urlutil "github.com/corazawaf/coraza/v3/internal/url" "github.com/corazawaf/coraza/v3/types" @@ -1448,12 +1450,23 @@ func (tx *Transaction) AuditLog() *auditlog.Log { func (tx *Transaction) Close() error { defer tx.WAF.txPool.Put(tx) tx.variables.reset() + var errs []error if err := tx.requestBodyBuffer.Reset(); err != nil { - errs = append(errs, err) + errs = append(errs, fmt.Errorf("reseting request body buffer: %v", err)) } if err := tx.responseBodyBuffer.Reset(); err != nil { - errs = append(errs, err) + errs = append(errs, fmt.Errorf("reseting response body buffer: %v", err)) + } + + if environment.HasAccessToFS { + for _, k := range tx.variables.filesTmpContent.Keys() { + for _, tmpContent := range tx.variables.filesTmpContent.Get(k) { + if err := os.Remove(tmpContent); err != nil { + errs = append(errs, fmt.Errorf("deleting content temporary file: %v", err)) + } + } + } } if tx.IsInterrupted() { @@ -1468,14 +1481,11 @@ func (tx *Transaction) Close() error { Msg("Transaction finished") } - switch { - case len(errs) == 0: + if len(errs) == 0 { return nil - case len(errs) == 1: - return fmt.Errorf("transaction close failed: %s", errs[0].Error()) - default: - return fmt.Errorf("transaction close failed:\n- %s\n- %s", errs[0].Error(), errs[1].Error()) } + + return fmt.Errorf("transaction close failed: %v", errors.Join(errs...)) } // String will return a string with the transaction debug information diff --git a/internal/corazawaf/transaction_test.go b/internal/corazawaf/transaction_test.go index 938aef632..241820f67 100644 --- a/internal/corazawaf/transaction_test.go +++ b/internal/corazawaf/transaction_test.go @@ -98,7 +98,7 @@ func TestTxMultipart(t *testing.T) { tx.RequestBodyLimit = 9999999 _, err := tx.ParseRequestReader(strings.NewReader(data)) if err != nil { - t.Error("Failed to parse multipart request: " + err.Error()) + t.Fatal("Failed to parse multipart request: " + err.Error()) } exp := map[string]string{ "%{args_post.text}": "test-value", @@ -108,6 +108,10 @@ func TestTxMultipart(t *testing.T) { } validateMacroExpansion(exp, tx, t) + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } } func TestTxResponse(t *testing.T) { @@ -215,7 +219,7 @@ func TestWriteRequestBody(t *testing.T) { for _, c := range chunks { if it, _, err = writeRequestBody(tx, c); err != nil { - t.Errorf("Failed to write body buffer: %s", err.Error()) + t.Fatalf("Failed to write body buffer: %s", err.Error()) } } @@ -235,11 +239,13 @@ func TestWriteRequestBody(t *testing.T) { val := tx.variables.argsPost.Get("some") if len(val) != 1 || val[0] != "result" { - t.Errorf("Failed to set urlencoded POST data with arguments: \"%s\"", strings.Join(val, "\", \"")) + t.Fatalf("Failed to set urlencoded POST data with arguments: \"%s\"", strings.Join(val, "\", \"")) } } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } }) } @@ -296,7 +302,9 @@ func TestWriteRequestBodyOnLimitReached(t *testing.T) { t.Fatalf("unexpected number of bytes written") } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } }) } }) @@ -343,7 +351,9 @@ func TestWriteRequestBodyIsNopWhenBodyIsNotAccesible(t *testing.T) { t.Fatalf("unexpected number of bytes written") } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } }) } }) @@ -354,12 +364,12 @@ func TestResponseHeader(t *testing.T) { tx := makeTransaction(t) tx.AddResponseHeader("content-type", "test") if tx.variables.responseContentType.Get() != "test" { - t.Error("invalid RESPONSE_CONTENT_TYPE after response headers") + t.Fatal("invalid RESPONSE_CONTENT_TYPE after response headers") } interruption := tx.ProcessResponseHeaders(200, "OK") if interruption != nil { - t.Error("unexpected interruption") + t.Fatal("unexpected interruption") } } @@ -368,12 +378,16 @@ func TestProcessRequestHeadersDoesNoEvaluationOnEngineOff(t *testing.T) { tx.RuleEngine = types.RuleEngineOff if !tx.IsRuleEngineOff() { - t.Error("expected Engine off") + t.Fatal("expected Engine off") } _ = tx.ProcessRequestHeaders() if tx.lastPhase != 0 { // 0 means no phases have been evaluated - t.Error("unexpected rule evaluation") + t.Fatal("unexpected rule evaluation") + } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -381,10 +395,13 @@ func TestProcessRequestBodyDoesNoEvaluationOnEngineOff(t *testing.T) { tx := NewWAF().NewTransaction() tx.RuleEngine = types.RuleEngineOff if _, err := tx.ProcessRequestBody(); err != nil { - t.Error("failed to process request body") + t.Fatal("failed to process request body") } if tx.lastPhase != 0 { - t.Error("unexpected rule evaluation") + t.Fatal("unexpected rule evaluation") + } + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -393,7 +410,7 @@ func TestProcessResponseHeadersDoesNoEvaluationOnEngineOff(t *testing.T) { tx.RuleEngine = types.RuleEngineOff _ = tx.ProcessResponseHeaders(200, "OK") if tx.lastPhase != 0 { - t.Error("unexpected rule evaluation") + t.Fatal("unexpected rule evaluation") } } @@ -401,10 +418,10 @@ func TestProcessResponseBodyDoesNoEvaluationOnEngineOff(t *testing.T) { tx := NewWAF().NewTransaction() tx.RuleEngine = types.RuleEngineOff if _, err := tx.ProcessResponseBody(); err != nil { - t.Error("Failed to process response body") + t.Fatal("Failed to process response body") } if tx.lastPhase != 0 { - t.Error("unexpected rule evaluation") + t.Fatal("unexpected rule evaluation") } } @@ -413,7 +430,10 @@ func TestProcessLoggingDoesNoEvaluationOnEngineOff(t *testing.T) { tx.RuleEngine = types.RuleEngineOff tx.ProcessLogging() if tx.lastPhase != 0 { - t.Error("unexpected rule evaluation") + t.Fatal("unexpected rule evaluation") + } + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -422,11 +442,11 @@ func TestAuditLog(t *testing.T) { tx.AuditLogParts = types.AuditLogParts("ABCDEFGHIJK") al := tx.AuditLog() if al.Transaction().ID() != tx.id { - t.Error("invalid auditlog id") + t.Fatal("invalid auditlog id") } // TODO more checks if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -509,7 +529,7 @@ func TestWriteResponseBody(t *testing.T) { for _, c := range chunks { if it, _, err = writeResponseBody(tx, c); err != nil { - t.Errorf("Failed to write body buffer: %s", err.Error()) + t.Fatalf("Failed to write body buffer: %s", err.Error()) } } @@ -529,11 +549,13 @@ func TestWriteResponseBody(t *testing.T) { // checking if the body has been populated up to the first POST arg index := strings.Index(urlencodedBody, "&") if tx.variables.responseBody.Get()[:index] != urlencodedBody[:index] { - t.Error("failed to set response body") + t.Fatal("failed to set response body") } } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } }) } @@ -590,7 +612,9 @@ func TestWriteResponseBodyOnLimitReached(t *testing.T) { t.Fatalf("unexpected number of bytes written") } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } }) } }) @@ -637,7 +661,9 @@ func TestWriteResponseBodyIsNopWhenBodyIsNotAccesible(t *testing.T) { t.Fatalf("unexpected number of bytes written") } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } }) } }) @@ -658,21 +684,21 @@ func TestAuditLogFields(t *testing.T) { }, }) if len(tx.matchedRules) == 0 || tx.matchedRules[0].Rule().ID() != rule.ID_ { - t.Error("failed to match rule for audit") + t.Fatal("failed to match rule for audit") } al := tx.AuditLog() if len(al.Messages()) == 0 || al.Messages()[0].Data().ID() != rule.ID_ { - t.Error("failed to add rules to audit logs") + t.Fatal("failed to add rules to audit logs") } if len(al.Transaction().Request().Headers()) == 0 || al.Transaction().Request().Headers()["test"][0] != "test" { - t.Error("failed to add request header to audit log") + t.Fatal("failed to add request header to audit log") } if len(al.Transaction().Response().Headers()) == 0 || al.Transaction().Response().Headers()["test"][0] != "test" { - t.Error("failed to add Response header to audit log") + t.Fatal("failed to add Response header to audit log") } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -681,14 +707,14 @@ func TestResetCapture(t *testing.T) { tx.Capture = true tx.CaptureField(5, "test") if tx.variables.tx.Get("5")[0] != "test" { - t.Error("failed to set capture field from tx") + t.Fatal("failed to set capture field from tx") } tx.resetCaptures() if tx.variables.tx.Get("5")[0] != "" { - t.Error("failed to reset capture field from tx") + t.Fatal("failed to reset capture field from tx") } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -701,7 +727,7 @@ func TestRelevantAuditLogging(t *testing.T) { tx.ProcessLogging() // TODO how do we check if the log was writen? if err := tx.Close(); err != nil { - t.Error(err) + t.Fatal(err) } } @@ -770,13 +796,13 @@ func TestLogCallback(t *testing.T) { } if buffer == "" || !strings.Contains(buffer, tx.id) { - t.Error("failed to call error log callback") + t.Fatal("failed to call error log callback") } if !strings.Contains(buffer, testCase.expectedLogLine) { - t.Errorf("Expected string \"%s\" with disruptive rule, got %s", testCase.expectedLogLine, buffer) + t.Fatalf("Expected string \"%s\" with disruptive rule, got %s", testCase.expectedLogLine, buffer) if err := tx.Close(); err != nil { - t.Error(err) + t.Fatal(err) } } }) @@ -790,19 +816,19 @@ func TestHeaderSetters(t *testing.T) { tx.AddRequestHeader("test1", "test2") c := tx.variables.requestCookies.Get("abc")[0] if c != "def" { - t.Errorf("failed to set cookie, got %q", c) + t.Fatalf("failed to set cookie, got %q", c) } if tx.variables.requestHeaders.Get("cookie")[0] != "abc=def;hij=klm" { - t.Error("failed to set request header") + t.Fatal("failed to set request header") } if !utils.InSlice("cookie", collectionValues(t, tx.variables.requestHeadersNames)) { - t.Error("failed to set header name", collectionValues(t, tx.variables.requestHeadersNames)) + t.Fatal("failed to set header name", collectionValues(t, tx.variables.requestHeadersNames)) } if !utils.InSlice("abc", collectionValues(t, tx.variables.requestCookiesNames)) { - t.Error("failed to set cookie name") + t.Fatal("failed to set cookie name") } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -825,16 +851,16 @@ func TestRequestBodyProcessingAlgorithm(t *testing.T) { tx.AddRequestHeader("content-length", "7") tx.ProcessRequestHeaders() if _, err := tx.requestBodyBuffer.Write([]byte("test123")); err != nil { - t.Error("Failed to write request body buffer") + t.Fatal("Failed to write request body buffer") } if _, err := tx.ProcessRequestBody(); err != nil { - t.Error("failed to process request body") + t.Fatal("failed to process request body") } if tx.variables.requestBody.Get() != "test123" { - t.Error("failed to set request body") + t.Fatal("failed to set request body") } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -876,7 +902,7 @@ func TestProcessBodiesSkippedIfHeadersPhasesNotReached(t *testing.T) { t.Fatalf("unexpected message, want %q, have %q", want, have) } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -888,31 +914,31 @@ func TestTxVariables(t *testing.T) { KeyRx: regexp.MustCompile("ho.*"), } if len(tx.GetField(rv)) != 1 || tx.GetField(rv)[0].Value() != "www.test.com:80" { - t.Errorf("failed to match rule variable REQUEST_HEADERS:host, %d matches, %v", len(tx.GetField(rv)), tx.GetField(rv)) + t.Fatalf("failed to match rule variable REQUEST_HEADERS:host, %d matches, %v", len(tx.GetField(rv)), tx.GetField(rv)) } rv.Count = true if len(tx.GetField(rv)) == 0 || tx.GetField(rv)[0].Value() != "1" { - t.Errorf("failed to get count for regexp variable") + t.Fatalf("failed to get count for regexp variable") } // now nil key rv.KeyRx = nil if len(tx.GetField(rv)) == 0 { - t.Error("failed to match rule variable REQUEST_HEADERS with nil key") + t.Fatal("failed to match rule variable REQUEST_HEADERS with nil key") } rv.KeyStr = "" f := tx.GetField(rv) if len(f) == 0 { - t.Error("failed to count variable REQUEST_HEADERS ") + t.Fatal("failed to count variable REQUEST_HEADERS ") } count, err := strconv.Atoi(f[0].Value()) if err != nil { - t.Error(err) + t.Fatal(err) } if count != 5 { - t.Errorf("failed to match rule variable REQUEST_HEADERS with count, %v", rv) + t.Fatalf("failed to match rule variable REQUEST_HEADERS with count, %v", rv) } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatal(err) } } @@ -928,12 +954,12 @@ func TestTxVariablesExceptions(t *testing.T) { } fields := tx.GetField(rv) if len(fields) != 0 { - t.Errorf("REQUEST_HEADERS:host should not match, got %d matches, %v", len(fields), fields) + t.Fatalf("REQUEST_HEADERS:host should not match, got %d matches, %v", len(fields), fields) } rv.Exceptions = nil fields = tx.GetField(rv) if len(fields) != 1 || fields[0].Value() != "www.test.com:80" { - t.Errorf("failed to match rule variable REQUEST_HEADERS:host, %d matches, %v", len(fields), fields) + t.Fatalf("failed to match rule variable REQUEST_HEADERS:host, %d matches, %v", len(fields), fields) } rv.Exceptions = []ruleVariableException{ { @@ -942,10 +968,10 @@ func TestTxVariablesExceptions(t *testing.T) { } fields = tx.GetField(rv) if len(fields) != 0 { - t.Errorf("REQUEST_HEADERS:host should not match, got %d matches, %v", len(fields), fields) + t.Fatalf("REQUEST_HEADERS:host should not match, got %d matches, %v", len(fields), fields) } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatal(err) } } @@ -959,11 +985,11 @@ func TestTransactionSyncPool(t *testing.T) { }) for i := 0; i < 1000; i++ { if err := tx.Close(); err != nil { - t.Error(err) + t.Fatal(err) } tx = waf.NewTransaction() if len(tx.matchedRules) != 0 { - t.Errorf("failed to sync transaction pool, %d rules found after %d attempts", len(tx.matchedRules), i+1) + t.Fatalf("failed to sync transaction pool, %d rules found after %d attempts", len(tx.matchedRules), i+1) return } } @@ -981,16 +1007,16 @@ func TestTxPhase4Magic(t *testing.T) { _, _ = tx.ProcessRequestBody() tx.ProcessResponseHeaders(200, "HTTP/1.1") if it, _, err := tx.WriteResponseBody([]byte("more bytes")); it != nil || err != nil { - t.Error(err) + t.Fatal(err) } if _, err := tx.ProcessResponseBody(); err != nil { - t.Error(err) + t.Fatal(err) } if tx.variables.outboundDataError.Get() != "1" { - t.Error("failed to set outbound data error") + t.Fatal("failed to set outbound data error") } if tx.variables.responseBody.Get() != "mor" { - t.Error("failed to set response body") + t.Fatal("failed to set response body") } } @@ -1009,12 +1035,16 @@ func TestVariablesMatch(t *testing.T) { for k, v := range expect { if m := (tx.Collection(k)).(*collections.Single).Get(); m != v { - t.Errorf("failed to match variable %s, Expected: %s, got: %s", k.Name(), v, m) + t.Fatalf("failed to match variable %s, Expected: %s, got: %s", k.Name(), v, m) } } if len(tx.variables.matchedVars.Get("ARGS_NAMES:sample")) == 0 { - t.Errorf("failed to match variable %s, got 0", variables.MatchedVars.Name()) + t.Fatalf("failed to match variable %s, got 0", variables.MatchedVars.Name()) + } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -1025,13 +1055,17 @@ func TestTxReqBodyForce(t *testing.T) { tx.RequestBodyAccess = true tx.ForceRequestBodyVariable = true if _, err := tx.requestBodyBuffer.Write([]byte("test")); err != nil { - t.Error(err) + t.Fatal(err) } if _, err := tx.ProcessRequestBody(); err != nil { - t.Error(err) + t.Fatal(err) } if tx.variables.requestBody.Get() != "test" { - t.Error("failed to set request body") + t.Fatal("failed to set request body") + } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -1041,13 +1075,17 @@ func TestTxReqBodyForceNegative(t *testing.T) { tx.RequestBodyAccess = true tx.ForceRequestBodyVariable = false if _, err := tx.requestBodyBuffer.Write([]byte("test")); err != nil { - t.Error(err) + t.Fatal(err) } if _, err := tx.ProcessRequestBody(); err != nil { - t.Error(err) + t.Fatal(err) } if tx.variables.requestBody.Get() == "test" { - t.Error("reqbody should not be there") + t.Fatal("reqbody should not be there") + } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -1056,10 +1094,14 @@ func TestTxProcessConnection(t *testing.T) { tx := waf.NewTransaction() tx.ProcessConnection("127.0.0.1", 80, "127.0.0.2", 8080) if tx.variables.remoteAddr.Get() != "127.0.0.1" { - t.Error("failed to set client ip") + t.Fatal("failed to set client ip") } if rp, _ := strconv.Atoi(tx.variables.remotePort.Get()); rp != 80 { - t.Error("failed to set client port") + t.Fatal("failed to set client port") + } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -1074,7 +1116,7 @@ func TestTxSetServerName(t *testing.T) { tx.lastPhase = types.PhaseRequestHeaders tx.SetServerName("coraza.io") if tx.variables.serverName.Get() != "coraza.io" { - t.Error("failed to set server name") + t.Fatal("failed to set server name") } logEntries := strings.Split(strings.TrimSpace(logBuffer.String()), "\n") if want, have := 1, len(logEntries); want != have { @@ -1084,6 +1126,10 @@ func TestTxSetServerName(t *testing.T) { if want, have := "SetServerName has been called after ProcessRequestHeaders", logEntries[0]; !strings.Contains(have, want) { t.Fatalf("unexpected message, want %q, have %q", want, have) } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } } func TestTxAddArgument(t *testing.T) { @@ -1092,15 +1138,19 @@ func TestTxAddArgument(t *testing.T) { tx.ProcessConnection("127.0.0.1", 80, "127.0.0.2", 8080) tx.AddGetRequestArgument("test", "testvalue") if tx.variables.argsGet.Get("test")[0] != "testvalue" { - t.Error("failed to set args get") + t.Fatal("failed to set args get") } tx.AddPostRequestArgument("ptest", "ptestvalue") if tx.variables.argsPost.Get("ptest")[0] != "ptestvalue" { - t.Error("failed to set args post") + t.Fatal("failed to set args post") } tx.AddPathRequestArgument("ptest2", "ptestvalue") if tx.variables.argsPath.Get("ptest2")[0] != "ptestvalue" { - t.Error("failed to set args post") + t.Fatal("failed to set args post") + } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -1110,7 +1160,11 @@ func TestTxGetField(t *testing.T) { Variable: variables.Args, } if f := tx.GetField(rvp); len(f) != 3 { - t.Errorf("failed to get field, expected 2, got %d", len(f)) + t.Fatalf("failed to get field, expected 2, got %d", len(f)) + } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -1120,19 +1174,23 @@ func TestTxProcessURI(t *testing.T) { uri := "http://example.com/path/to/file.html?query=string&other=value" tx.ProcessURI(uri, "GET", "HTTP/1.1") if s := tx.variables.requestURI.Get(); s != uri { - t.Errorf("failed to set request uri, got %s", s) + t.Fatalf("failed to set request uri, got %s", s) } if s := tx.variables.requestBasename.Get(); s != "file.html" { - t.Errorf("failed to set request path, got %s", s) + t.Fatalf("failed to set request path, got %s", s) } if tx.variables.queryString.Get() != "query=string&other=value" { - t.Error("failed to set request query") + t.Fatal("failed to set request query") } if v := tx.variables.args.FindAll(); len(v) != 2 { - t.Errorf("failed to set request args, got %d", len(v)) + t.Fatalf("failed to set request args, got %d", len(v)) } if v := tx.variables.args.FindString("other"); v[0].Value() != "value" { - t.Errorf("failed to set request args, got %v", v) + t.Fatalf("failed to set request args, got %v", v) + } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) } } @@ -1208,7 +1266,7 @@ func validateMacroExpansion(tests map[string]string, tx *Transaction, t *testing for k, v := range tests { m, err := macro.NewMacro(k) if err != nil { - t.Error(err) + t.Fatal(err) } res := m.Expand(tx) if res != v { @@ -1216,7 +1274,7 @@ func validateMacroExpansion(tests map[string]string, tx *Transaction, t *testing fmt.Println(tx) fmt.Println("===STACK===\n", string(debug.Stack())+"\n===STACK===") } - t.Error("Failed set transaction for " + k + ", expected " + v + ", got " + res) + t.Fatal("Failed set transaction for " + k + ", expected " + v + ", got " + res) } } } @@ -1226,28 +1284,32 @@ func TestMacro(t *testing.T) { tx.variables.tx.Set("some", []string{"secretly"}) m, err := macro.NewMacro("%{unique_id}") if err != nil { - t.Error(err) + t.Fatal(err) } if m.Expand(tx) != tx.id { - t.Errorf("%s != %s", m.Expand(tx), tx.id) + t.Fatalf("%s != %s", m.Expand(tx), tx.id) } m, err = macro.NewMacro("some complex text %{tx.some} wrapped in m") if err != nil { - t.Error(err) + t.Fatal(err) } if m.Expand(tx) != "some complex text secretly wrapped in m" { - t.Errorf("failed to expand m, got %s\n%v", m.Expand(tx), m) + t.Fatalf("failed to expand m, got %s\n%v", m.Expand(tx), m) } _, err = macro.NewMacro("some complex text %{tx.some} wrapped in m %{tx.some}") if err != nil { - t.Error(err) + t.Fatal(err) return } // TODO(anuraaga): Decouple this test from transaction implementation. // if !macro.IsExpandable() || len(macro.tokens) != 4 || macro.Expand(tx) != "some complex text secretly wrapped in m secretly" { - // t.Errorf("failed to parse replacements %v", macro.tokens) + // t.Fatalf("failed to parse replacements %v", macro.tokens) // } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } } func BenchmarkMacro(b *testing.B) { @@ -1336,6 +1398,10 @@ func TestProcessorsIdempotencyWithAlreadyRaisedInterruption(t *testing.T) { } }) } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } } func TestIterationStops(t *testing.T) { @@ -1361,16 +1427,20 @@ func TestIterationStops(t *testing.T) { }) if want, have := i+1, len(haveVars); want != have { - t.Errorf("stopped with unexpected number of variables, want %d, have %d", want, have) + t.Fatalf("stopped with unexpected number of variables, want %d, have %d", want, have) } for j, v := range haveVars { if want, have := allVars[j], v; want != have { - t.Errorf("unexpected variable at index %d, want %s, have %s", j, want.Name(), have.Name()) + t.Fatalf("unexpected variable at index %d, want %s, have %s", j, want.Name(), have.Name()) } } }) } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } } func TestTxAddResponseArgs(t *testing.T) { @@ -1378,7 +1448,7 @@ func TestTxAddResponseArgs(t *testing.T) { tx := waf.NewTransaction() tx.AddResponseArgument("samplekey", "samplevalue") if tx.variables.responseArgs.Get("samplekey")[0] != "samplevalue" { - t.Errorf("failed to add response argument") + t.Fatalf("failed to add response argument") } } @@ -1395,6 +1465,10 @@ func TestAddGetArgsWithOverlimit(t *testing.T) { if tx.variables.argsGet.Len() > waf.ArgumentLimit { t.Fatal("Argument limit is failed while add get args") } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } } } @@ -1411,6 +1485,10 @@ func TestAddPostArgsWithOverlimit(t *testing.T) { if tx.variables.argsPost.Len() > waf.ArgumentLimit { t.Fatal("Argument limit is failed while add post args") } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } } } @@ -1427,6 +1505,10 @@ func TestAddPathArgsWithOverlimit(t *testing.T) { if tx.variables.argsPath.Len() > waf.ArgumentLimit { t.Fatal("Argument limit is failed while add path args") } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } } } @@ -1443,6 +1525,10 @@ func TestAddResponseArgsWithOverlimit(t *testing.T) { if tx.variables.responseArgs.Len() > waf.ArgumentLimit { t.Fatal("Argument limit is failed while add response args") } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } } } @@ -1467,6 +1553,10 @@ func TestResponseBodyForceProcessing(t *testing.T) { if len(f) == 0 { t.Fatal("json.key not found") } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) + } } func TestForceRequestBodyOverride(t *testing.T) { @@ -1477,24 +1567,28 @@ func TestForceRequestBodyOverride(t *testing.T) { tx.variables.RequestBodyProcessor().(*collections.Single).Set("JSON") tx.ProcessRequestHeaders() if _, _, err := tx.WriteRequestBody([]byte("foo=bar&baz=qux")); err != nil { - t.Errorf("Failed to write request body: %v", err) + t.Fatalf("Failed to write request body: %v", err) } if _, err := tx.ProcessRequestBody(); err != nil { - t.Errorf("Failed to process request body: %v", err) + t.Fatalf("Failed to process request body: %v", err) } if tx.variables.RequestBodyProcessor().Get() != "JSON" { - t.Errorf("Failed to force request body variable") + t.Fatalf("Failed to force request body variable") } tx = waf.NewTransaction() tx.ForceRequestBodyVariable = true tx.ProcessRequestHeaders() if _, _, err := tx.WriteRequestBody([]byte("foo=bar&baz=qux")); err != nil { - t.Errorf("Failed to write request body: %v", err) + t.Fatalf("Failed to write request body: %v", err) } if _, err := tx.ProcessRequestBody(); err != nil { - t.Errorf("Failed to process request body: %v", err) + t.Fatalf("Failed to process request body: %v", err) } if tx.variables.RequestBodyProcessor().Get() != "URLENCODED" { - t.Errorf("Failed to force request body variable, got RBP: %q", tx.variables.RequestBodyProcessor().Get()) + t.Fatalf("Failed to force request body variable, got RBP: %q", tx.variables.RequestBodyProcessor().Get()) + } + + if err := tx.Close(); err != nil { + t.Fatalf("failed to close transaction: %s", err.Error()) } }