From 1455f1d15b7773b54667bf6c72990fda5096e743 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Ch=C3=A1vez?= Date: Thu, 9 May 2024 09:57:58 +0200 Subject: [PATCH] fix: deletes content temporary file on close. --- internal/corazawaf/transaction.go | 29 ++- internal/corazawaf/transaction_test.go | 305 +++++++++++++++++-------- 2 files changed, 227 insertions(+), 107 deletions(-) diff --git a/internal/corazawaf/transaction.go b/internal/corazawaf/transaction.go index 83459d0d4..b118da513 100644 --- a/internal/corazawaf/transaction.go +++ b/internal/corazawaf/transaction.go @@ -12,6 +12,7 @@ import ( "math" "mime" "net/url" + "os" "path/filepath" "strconv" "strings" @@ -26,6 +27,7 @@ import ( "github.com/corazawaf/coraza/v3/internal/cookies" "github.com/corazawaf/coraza/v3/internal/corazarules" "github.com/corazawaf/coraza/v3/internal/corazatypes" + "github.com/corazawaf/coraza/v3/internal/environment" stringsutil "github.com/corazawaf/coraza/v3/internal/strings" urlutil "github.com/corazawaf/coraza/v3/internal/url" "github.com/corazawaf/coraza/v3/types" @@ -1472,13 +1474,25 @@ func (tx *Transaction) AuditLog() *auditlog.Log { // It also allows caches the transaction back into the sync.Pool func (tx *Transaction) Close() error { defer tx.WAF.txPool.Put(tx) - tx.variables.reset() + var errs []error + if environment.HasAccessToFS { + // TODO(jcchavezs): filesTmpNames should probably be a new kind of collection that + // is aware of the files and then attempt to delete them when the collection + // is resetted or an item is removed. + for _, file := range tx.variables.filesTmpNames.Get("") { + if err := os.Remove(file); err != nil { + errs = append(errs, fmt.Errorf("removing temporary file: %v", err)) + } + } + } + + tx.variables.reset() if err := tx.requestBodyBuffer.Reset(); err != nil { - errs = append(errs, err) + errs = append(errs, fmt.Errorf("reseting request body buffer: %v", err)) } if err := tx.responseBodyBuffer.Reset(); err != nil { - errs = append(errs, err) + errs = append(errs, fmt.Errorf("reseting response body buffer: %v", err)) } if tx.IsInterrupted() { @@ -1493,14 +1507,11 @@ func (tx *Transaction) Close() error { Msg("Transaction finished") } - switch { - case len(errs) == 0: + if len(errs) == 0 { return nil - case len(errs) == 1: - return fmt.Errorf("transaction close failed: %s", errs[0].Error()) - default: - return fmt.Errorf("transaction close failed:\n- %s\n- %s", errs[0].Error(), errs[1].Error()) } + + return fmt.Errorf("transaction close failed: %v", errors.Join(errs...)) } // String will return a string with the transaction debug information diff --git a/internal/corazawaf/transaction_test.go b/internal/corazawaf/transaction_test.go index b1db6a86a..1ef7fefc8 100644 --- a/internal/corazawaf/transaction_test.go +++ b/internal/corazawaf/transaction_test.go @@ -98,7 +98,7 @@ func TestTxMultipart(t *testing.T) { tx.RequestBodyLimit = 9999999 _, err := tx.ParseRequestReader(strings.NewReader(data)) if err != nil { - t.Error("Failed to parse multipart request: " + err.Error()) + t.Fatal("Failed to parse multipart request: " + err.Error()) } exp := map[string]string{ "%{args_post.text}": "test-value", @@ -108,6 +108,10 @@ func TestTxMultipart(t *testing.T) { } validateMacroExpansion(exp, tx, t) + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } } func TestTxResponse(t *testing.T) { @@ -225,7 +229,7 @@ func TestWriteRequestBody(t *testing.T) { for _, c := range chunks { if it, _, err = writeRequestBody(tx, c); err != nil { - t.Errorf("Failed to write body buffer: %s", err.Error()) + t.Fatalf("Failed to write body buffer: %s", err.Error()) } } @@ -245,11 +249,13 @@ func TestWriteRequestBody(t *testing.T) { val := tx.variables.argsPost.Get("some") if len(val) != 1 || val[0] != "result" { - t.Errorf("Failed to set urlencoded POST data with arguments: \"%s\"", strings.Join(val, "\", \"")) + t.Fatalf("Failed to set urlencoded POST data with arguments: \"%s\"", strings.Join(val, "\", \"")) } } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } }) } @@ -306,7 +312,9 @@ func TestWriteRequestBodyOnLimitReached(t *testing.T) { t.Fatalf("unexpected number of bytes written") } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } }) } }) @@ -353,7 +361,9 @@ func TestWriteRequestBodyIsNopWhenBodyIsNotAccesible(t *testing.T) { t.Fatalf("unexpected number of bytes written") } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } }) } }) @@ -364,12 +374,12 @@ func TestResponseHeader(t *testing.T) { tx := makeTransaction(t) tx.AddResponseHeader("content-type", "test") if tx.variables.responseContentType.Get() != "test" { - t.Error("invalid RESPONSE_CONTENT_TYPE after response headers") + t.Fatal("invalid RESPONSE_CONTENT_TYPE after response headers") } interruption := tx.ProcessResponseHeaders(200, "OK") if interruption != nil { - t.Error("unexpected interruption") + t.Fatal("unexpected interruption") } } @@ -378,12 +388,16 @@ func TestProcessRequestHeadersDoesNoEvaluationOnEngineOff(t *testing.T) { tx.RuleEngine = types.RuleEngineOff if !tx.IsRuleEngineOff() { - t.Error("expected Engine off") + t.Fatal("expected Engine off") } _ = tx.ProcessRequestHeaders() if tx.lastPhase != 0 { // 0 means no phases have been evaluated - t.Error("unexpected rule evaluation") + t.Fatal("unexpected rule evaluation") + } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -391,10 +405,13 @@ func TestProcessRequestBodyDoesNoEvaluationOnEngineOff(t *testing.T) { tx := NewWAF().NewTransaction() tx.RuleEngine = types.RuleEngineOff if _, err := tx.ProcessRequestBody(); err != nil { - t.Error("failed to process request body") + t.Fatal("failed to process request body") } if tx.lastPhase != 0 { - t.Error("unexpected rule evaluation") + t.Fatal("unexpected rule evaluation") + } + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -403,7 +420,7 @@ func TestProcessResponseHeadersDoesNoEvaluationOnEngineOff(t *testing.T) { tx.RuleEngine = types.RuleEngineOff _ = tx.ProcessResponseHeaders(200, "OK") if tx.lastPhase != 0 { - t.Error("unexpected rule evaluation") + t.Fatal("unexpected rule evaluation") } } @@ -411,10 +428,10 @@ func TestProcessResponseBodyDoesNoEvaluationOnEngineOff(t *testing.T) { tx := NewWAF().NewTransaction() tx.RuleEngine = types.RuleEngineOff if _, err := tx.ProcessResponseBody(); err != nil { - t.Error("Failed to process response body") + t.Fatal("Failed to process response body") } if tx.lastPhase != 0 { - t.Error("unexpected rule evaluation") + t.Fatal("unexpected rule evaluation") } } @@ -423,7 +440,10 @@ func TestProcessLoggingDoesNoEvaluationOnEngineOff(t *testing.T) { tx.RuleEngine = types.RuleEngineOff tx.ProcessLogging() if tx.lastPhase != 0 { - t.Error("unexpected rule evaluation") + t.Fatal("unexpected rule evaluation") + } + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -432,11 +452,11 @@ func TestAuditLog(t *testing.T) { tx.AuditLogParts = types.AuditLogParts("ABCDEFGHIJK") al := tx.AuditLog() if al.Transaction().ID() != tx.id { - t.Error("invalid auditlog id") + t.Fatal("invalid auditlog id") } // TODO more checks if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -525,7 +545,7 @@ func TestWriteResponseBody(t *testing.T) { for _, c := range chunks { if it, _, err = writeResponseBody(tx, c); err != nil { - t.Errorf("Failed to write body buffer: %s", err.Error()) + t.Fatalf("Failed to write body buffer: %s", err.Error()) } } @@ -545,11 +565,13 @@ func TestWriteResponseBody(t *testing.T) { // checking if the body has been populated up to the first POST arg index := strings.Index(urlencodedBody, "&") if tx.variables.responseBody.Get()[:index] != urlencodedBody[:index] { - t.Error("failed to set response body") + t.Fatal("failed to set response body") } } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } }) } @@ -606,7 +628,9 @@ func TestWriteResponseBodyOnLimitReached(t *testing.T) { t.Fatalf("unexpected number of bytes written") } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } }) } }) @@ -653,7 +677,9 @@ func TestWriteResponseBodyIsNopWhenBodyIsNotAccesible(t *testing.T) { t.Fatalf("unexpected number of bytes written") } - _ = tx.Close() + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } }) } }) @@ -674,21 +700,21 @@ func TestAuditLogFields(t *testing.T) { }, }) if len(tx.matchedRules) == 0 || tx.matchedRules[0].Rule().ID() != rule.ID_ { - t.Error("failed to match rule for audit") + t.Fatal("failed to match rule for audit") } al := tx.AuditLog() if len(al.Messages()) == 0 || al.Messages()[0].Data().ID() != rule.ID_ { - t.Error("failed to add rules to audit logs") + t.Fatal("failed to add rules to audit logs") } if len(al.Transaction().Request().Headers()) == 0 || al.Transaction().Request().Headers()["test"][0] != "test" { - t.Error("failed to add request header to audit log") + t.Fatal("failed to add request header to audit log") } if len(al.Transaction().Response().Headers()) == 0 || al.Transaction().Response().Headers()["test"][0] != "test" { - t.Error("failed to add Response header to audit log") + t.Fatal("failed to add Response header to audit log") } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -697,14 +723,14 @@ func TestResetCapture(t *testing.T) { tx.Capture = true tx.CaptureField(5, "test") if tx.variables.tx.Get("5")[0] != "test" { - t.Error("failed to set capture field from tx") + t.Fatal("failed to set capture field from tx") } tx.resetCaptures() if tx.variables.tx.Get("5")[0] != "" { - t.Error("failed to reset capture field from tx") + t.Fatal("failed to reset capture field from tx") } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -827,13 +853,13 @@ func TestLogCallback(t *testing.T) { } if buffer == "" || !strings.Contains(buffer, tx.id) { - t.Error("failed to call error log callback") + t.Fatal("failed to call error log callback") } if !strings.Contains(buffer, testCase.expectedLogLine) { - t.Errorf("Expected string \"%s\" with disruptive rule, got %s", testCase.expectedLogLine, buffer) + t.Fatalf("Expected string \"%s\" with disruptive rule, got %s", testCase.expectedLogLine, buffer) if err := tx.Close(); err != nil { - t.Error(err) + t.Fatal(err) } } }) @@ -847,19 +873,19 @@ func TestHeaderSetters(t *testing.T) { tx.AddRequestHeader("test1", "test2") c := tx.variables.requestCookies.Get("abc")[0] if c != "def" { - t.Errorf("failed to set cookie, got %q", c) + t.Fatalf("failed to set cookie, got %q", c) } if tx.variables.requestHeaders.Get("cookie")[0] != "abc=def;hij=klm" { - t.Error("failed to set request header") + t.Fatal("failed to set request header") } if !utils.InSlice("cookie", collectionValues(t, tx.variables.requestHeadersNames)) { - t.Error("failed to set header name", collectionValues(t, tx.variables.requestHeadersNames)) + t.Fatal("failed to set header name", collectionValues(t, tx.variables.requestHeadersNames)) } if !utils.InSlice("abc", collectionValues(t, tx.variables.requestCookiesNames)) { - t.Error("failed to set cookie name") + t.Fatal("failed to set cookie name") } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -933,16 +959,16 @@ func TestRequestBodyProcessingAlgorithm(t *testing.T) { tx.AddRequestHeader("content-length", "7") tx.ProcessRequestHeaders() if _, err := tx.requestBodyBuffer.Write([]byte("test123")); err != nil { - t.Error("Failed to write request body buffer") + t.Fatal("Failed to write request body buffer") } if _, err := tx.ProcessRequestBody(); err != nil { - t.Error("failed to process request body") + t.Fatal("failed to process request body") } if tx.variables.requestBody.Get() != "test123" { - t.Error("failed to set request body") + t.Fatal("failed to set request body") } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -984,7 +1010,7 @@ func TestProcessBodiesSkippedIfHeadersPhasesNotReached(t *testing.T) { t.Fatalf("unexpected message, want %q, have %q", want, have) } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -996,31 +1022,31 @@ func TestTxVariables(t *testing.T) { KeyRx: regexp.MustCompile("ho.*"), } if len(tx.GetField(rv)) != 1 || tx.GetField(rv)[0].Value() != "www.test.com:80" { - t.Errorf("failed to match rule variable REQUEST_HEADERS:host, %d matches, %v", len(tx.GetField(rv)), tx.GetField(rv)) + t.Fatalf("failed to match rule variable REQUEST_HEADERS:host, %d matches, %v", len(tx.GetField(rv)), tx.GetField(rv)) } rv.Count = true if len(tx.GetField(rv)) == 0 || tx.GetField(rv)[0].Value() != "1" { - t.Errorf("failed to get count for regexp variable") + t.Fatalf("failed to get count for regexp variable") } // now nil key rv.KeyRx = nil if len(tx.GetField(rv)) == 0 { - t.Error("failed to match rule variable REQUEST_HEADERS with nil key") + t.Fatal("failed to match rule variable REQUEST_HEADERS with nil key") } rv.KeyStr = "" f := tx.GetField(rv) if len(f) == 0 { - t.Error("failed to count variable REQUEST_HEADERS ") + t.Fatal("failed to count variable REQUEST_HEADERS ") } count, err := strconv.Atoi(f[0].Value()) if err != nil { - t.Error(err) + t.Fatal(err) } if count != 5 { - t.Errorf("failed to match rule variable REQUEST_HEADERS with count, %v", rv) + t.Fatalf("failed to match rule variable REQUEST_HEADERS with count, %v", rv) } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatal(err) } } @@ -1036,12 +1062,12 @@ func TestTxVariablesExceptions(t *testing.T) { } fields := tx.GetField(rv) if len(fields) != 0 { - t.Errorf("REQUEST_HEADERS:host should not match, got %d matches, %v", len(fields), fields) + t.Fatalf("REQUEST_HEADERS:host should not match, got %d matches, %v", len(fields), fields) } rv.Exceptions = nil fields = tx.GetField(rv) if len(fields) != 1 || fields[0].Value() != "www.test.com:80" { - t.Errorf("failed to match rule variable REQUEST_HEADERS:host, %d matches, %v", len(fields), fields) + t.Fatalf("failed to match rule variable REQUEST_HEADERS:host, %d matches, %v", len(fields), fields) } rv.Exceptions = []ruleVariableException{ { @@ -1050,10 +1076,10 @@ func TestTxVariablesExceptions(t *testing.T) { } fields = tx.GetField(rv) if len(fields) != 0 { - t.Errorf("REQUEST_HEADERS:host should not match, got %d matches, %v", len(fields), fields) + t.Fatalf("REQUEST_HEADERS:host should not match, got %d matches, %v", len(fields), fields) } if err := tx.Close(); err != nil { - t.Error(err) + t.Fatal(err) } } @@ -1067,11 +1093,11 @@ func TestTransactionSyncPool(t *testing.T) { }) for i := 0; i < 1000; i++ { if err := tx.Close(); err != nil { - t.Error(err) + t.Fatal(err) } tx = waf.NewTransaction() if len(tx.matchedRules) != 0 { - t.Errorf("failed to sync transaction pool, %d rules found after %d attempts", len(tx.matchedRules), i+1) + t.Fatalf("failed to sync transaction pool, %d rules found after %d attempts", len(tx.matchedRules), i+1) return } } @@ -1089,16 +1115,16 @@ func TestTxPhase4Magic(t *testing.T) { _, _ = tx.ProcessRequestBody() tx.ProcessResponseHeaders(200, "HTTP/1.1") if it, _, err := tx.WriteResponseBody([]byte("more bytes")); it != nil || err != nil { - t.Error(err) + t.Fatal(err) } if _, err := tx.ProcessResponseBody(); err != nil { - t.Error(err) + t.Fatal(err) } if tx.variables.outboundDataError.Get() != "1" { - t.Error("failed to set outbound data error") + t.Fatal("failed to set outbound data error") } if tx.variables.responseBody.Get() != "mor" { - t.Error("failed to set response body") + t.Fatal("failed to set response body") } } @@ -1117,12 +1143,16 @@ func TestVariablesMatch(t *testing.T) { for k, v := range expect { if m := (tx.Collection(k)).(*collections.Single).Get(); m != v { - t.Errorf("failed to match variable %s, Expected: %s, got: %s", k.Name(), v, m) + t.Fatalf("failed to match variable %s, Expected: %s, got: %s", k.Name(), v, m) } } if len(tx.variables.matchedVars.Get("ARGS_NAMES:sample")) == 0 { - t.Errorf("failed to match variable %s, got 0", variables.MatchedVars.Name()) + t.Fatalf("failed to match variable %s, got 0", variables.MatchedVars.Name()) + } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -1133,13 +1163,17 @@ func TestTxReqBodyForce(t *testing.T) { tx.RequestBodyAccess = true tx.ForceRequestBodyVariable = true if _, err := tx.requestBodyBuffer.Write([]byte("test")); err != nil { - t.Error(err) + t.Fatal(err) } if _, err := tx.ProcessRequestBody(); err != nil { - t.Error(err) + t.Fatal(err) } if tx.variables.requestBody.Get() != "test" { - t.Error("failed to set request body") + t.Fatal("failed to set request body") + } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -1149,13 +1183,17 @@ func TestTxReqBodyForceNegative(t *testing.T) { tx.RequestBodyAccess = true tx.ForceRequestBodyVariable = false if _, err := tx.requestBodyBuffer.Write([]byte("test")); err != nil { - t.Error(err) + t.Fatal(err) } if _, err := tx.ProcessRequestBody(); err != nil { - t.Error(err) + t.Fatal(err) } if tx.variables.requestBody.Get() == "test" { - t.Error("reqbody should not be there") + t.Fatal("reqbody should not be there") + } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -1164,10 +1202,14 @@ func TestTxProcessConnection(t *testing.T) { tx := waf.NewTransaction() tx.ProcessConnection("127.0.0.1", 80, "127.0.0.2", 8080) if tx.variables.remoteAddr.Get() != "127.0.0.1" { - t.Error("failed to set client ip") + t.Fatal("failed to set client ip") } if rp, _ := strconv.Atoi(tx.variables.remotePort.Get()); rp != 80 { - t.Error("failed to set client port") + t.Fatal("failed to set client port") + } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -1182,7 +1224,7 @@ func TestTxSetServerName(t *testing.T) { tx.lastPhase = types.PhaseRequestHeaders tx.SetServerName("coraza.io") if tx.variables.serverName.Get() != "coraza.io" { - t.Error("failed to set server name") + t.Fatal("failed to set server name") } logEntries := strings.Split(strings.TrimSpace(logBuffer.String()), "\n") if want, have := 1, len(logEntries); want != have { @@ -1192,6 +1234,10 @@ func TestTxSetServerName(t *testing.T) { if want, have := "SetServerName has been called after ProcessRequestHeaders", logEntries[0]; !strings.Contains(have, want) { t.Fatalf("unexpected message, want %q, have %q", want, have) } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } } func TestTxAddArgument(t *testing.T) { @@ -1200,15 +1246,19 @@ func TestTxAddArgument(t *testing.T) { tx.ProcessConnection("127.0.0.1", 80, "127.0.0.2", 8080) tx.AddGetRequestArgument("test", "testvalue") if tx.variables.argsGet.Get("test")[0] != "testvalue" { - t.Error("failed to set args get") + t.Fatal("failed to set args get") } tx.AddPostRequestArgument("ptest", "ptestvalue") if tx.variables.argsPost.Get("ptest")[0] != "ptestvalue" { - t.Error("failed to set args post") + t.Fatal("failed to set args post") } tx.AddPathRequestArgument("ptest2", "ptestvalue") if tx.variables.argsPath.Get("ptest2")[0] != "ptestvalue" { - t.Error("failed to set args post") + t.Fatal("failed to set args post") + } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -1218,7 +1268,11 @@ func TestTxGetField(t *testing.T) { Variable: variables.Args, } if f := tx.GetField(rvp); len(f) != 3 { - t.Errorf("failed to get field, expected 2, got %d", len(f)) + t.Fatalf("failed to get field, expected 2, got %d", len(f)) + } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -1228,19 +1282,23 @@ func TestTxProcessURI(t *testing.T) { uri := "http://example.com/path/to/file.html?query=string&other=value" tx.ProcessURI(uri, "GET", "HTTP/1.1") if s := tx.variables.requestURI.Get(); s != uri { - t.Errorf("failed to set request uri, got %s", s) + t.Fatalf("failed to set request uri, got %s", s) } if s := tx.variables.requestBasename.Get(); s != "file.html" { - t.Errorf("failed to set request path, got %s", s) + t.Fatalf("failed to set request path, got %s", s) } if tx.variables.queryString.Get() != "query=string&other=value" { - t.Error("failed to set request query") + t.Fatal("failed to set request query") } if v := tx.variables.args.FindAll(); len(v) != 2 { - t.Errorf("failed to set request args, got %d", len(v)) + t.Fatalf("failed to set request args, got %d", len(v)) } if v := tx.variables.args.FindString("other"); v[0].Value() != "value" { - t.Errorf("failed to set request args, got %v", v) + t.Fatalf("failed to set request args, got %v", v) + } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) } } @@ -1316,7 +1374,7 @@ func validateMacroExpansion(tests map[string]string, tx *Transaction, t *testing for k, v := range tests { m, err := macro.NewMacro(k) if err != nil { - t.Error(err) + t.Fatal(err) } res := m.Expand(tx) if res != v { @@ -1324,7 +1382,7 @@ func validateMacroExpansion(tests map[string]string, tx *Transaction, t *testing fmt.Println(tx) fmt.Println("===STACK===\n", string(debug.Stack())+"\n===STACK===") } - t.Error("Failed set transaction for " + k + ", expected " + v + ", got " + res) + t.Fatal("Failed set transaction for " + k + ", expected " + v + ", got " + res) } } } @@ -1334,28 +1392,32 @@ func TestMacro(t *testing.T) { tx.variables.tx.Set("some", []string{"secretly"}) m, err := macro.NewMacro("%{unique_id}") if err != nil { - t.Error(err) + t.Fatal(err) } if m.Expand(tx) != tx.id { - t.Errorf("%s != %s", m.Expand(tx), tx.id) + t.Fatalf("%s != %s", m.Expand(tx), tx.id) } m, err = macro.NewMacro("some complex text %{tx.some} wrapped in m") if err != nil { - t.Error(err) + t.Fatal(err) } if m.Expand(tx) != "some complex text secretly wrapped in m" { - t.Errorf("failed to expand m, got %s\n%v", m.Expand(tx), m) + t.Fatalf("failed to expand m, got %s\n%v", m.Expand(tx), m) } _, err = macro.NewMacro("some complex text %{tx.some} wrapped in m %{tx.some}") if err != nil { - t.Error(err) + t.Fatal(err) return } // TODO(anuraaga): Decouple this test from transaction implementation. // if !macro.IsExpandable() || len(macro.tokens) != 4 || macro.Expand(tx) != "some complex text secretly wrapped in m secretly" { - // t.Errorf("failed to parse replacements %v", macro.tokens) + // t.Fatalf("failed to parse replacements %v", macro.tokens) // } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } } func BenchmarkMacro(b *testing.B) { @@ -1444,6 +1506,10 @@ func TestProcessorsIdempotencyWithAlreadyRaisedInterruption(t *testing.T) { } }) } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } } func TestIterationStops(t *testing.T) { @@ -1469,16 +1535,20 @@ func TestIterationStops(t *testing.T) { }) if want, have := i+1, len(haveVars); want != have { - t.Errorf("stopped with unexpected number of variables, want %d, have %d", want, have) + t.Fatalf("stopped with unexpected number of variables, want %d, have %d", want, have) } for j, v := range haveVars { if want, have := allVars[j], v; want != have { - t.Errorf("unexpected variable at index %d, want %s, have %s", j, want.Name(), have.Name()) + t.Fatalf("unexpected variable at index %d, want %s, have %s", j, want.Name(), have.Name()) } } }) } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } } func TestTxAddResponseArgs(t *testing.T) { @@ -1486,7 +1556,7 @@ func TestTxAddResponseArgs(t *testing.T) { tx := waf.NewTransaction() tx.AddResponseArgument("samplekey", "samplevalue") if tx.variables.responseArgs.Get("samplekey")[0] != "samplevalue" { - t.Errorf("failed to add response argument") + t.Fatalf("failed to add response argument") } } @@ -1503,6 +1573,10 @@ func TestAddGetArgsWithOverlimit(t *testing.T) { if tx.variables.argsGet.Len() > waf.ArgumentLimit { t.Fatal("Argument limit is failed while add get args") } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } } } @@ -1519,6 +1593,10 @@ func TestAddPostArgsWithOverlimit(t *testing.T) { if tx.variables.argsPost.Len() > waf.ArgumentLimit { t.Fatal("Argument limit is failed while add post args") } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } } } @@ -1535,6 +1613,10 @@ func TestAddPathArgsWithOverlimit(t *testing.T) { if tx.variables.argsPath.Len() > waf.ArgumentLimit { t.Fatal("Argument limit is failed while add path args") } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } } } @@ -1551,6 +1633,10 @@ func TestAddResponseArgsWithOverlimit(t *testing.T) { if tx.variables.responseArgs.Len() > waf.ArgumentLimit { t.Fatal("Argument limit is failed while add response args") } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } } } @@ -1575,6 +1661,10 @@ func TestResponseBodyForceProcessing(t *testing.T) { if len(f) == 0 { t.Fatal("json.key not found") } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } } func TestForceRequestBodyOverride(t *testing.T) { @@ -1585,24 +1675,43 @@ func TestForceRequestBodyOverride(t *testing.T) { tx.variables.RequestBodyProcessor().(*collections.Single).Set("JSON") tx.ProcessRequestHeaders() if _, _, err := tx.WriteRequestBody([]byte("foo=bar&baz=qux")); err != nil { - t.Errorf("Failed to write request body: %v", err) + t.Fatalf("Failed to write request body: %v", err) } if _, err := tx.ProcessRequestBody(); err != nil { - t.Errorf("Failed to process request body: %v", err) + t.Fatalf("Failed to process request body: %v", err) } if tx.variables.RequestBodyProcessor().Get() != "JSON" { - t.Errorf("Failed to force request body variable") + t.Fatalf("Failed to force request body variable") } tx = waf.NewTransaction() tx.ForceRequestBodyVariable = true tx.ProcessRequestHeaders() if _, _, err := tx.WriteRequestBody([]byte("foo=bar&baz=qux")); err != nil { - t.Errorf("Failed to write request body: %v", err) + t.Fatalf("Failed to write request body: %v", err) } if _, err := tx.ProcessRequestBody(); err != nil { - t.Errorf("Failed to process request body: %v", err) + t.Fatalf("Failed to process request body: %v", err) } if tx.variables.RequestBodyProcessor().Get() != "URLENCODED" { - t.Errorf("Failed to force request body variable, got RBP: %q", tx.variables.RequestBodyProcessor().Get()) + t.Fatalf("Failed to force request body variable, got RBP: %q", tx.variables.RequestBodyProcessor().Get()) + } + + if err := tx.Close(); err != nil { + t.Fatalf("Failed to close transaction: %s", err.Error()) + } +} + +func TestCloseFails(t *testing.T) { + waf := NewWAF() + tx := waf.NewTransaction() + col := tx.Variables().FilesTmpNames().(*collections.Map) + col.Add("", "unexisting") + err := tx.Close() + if err == nil { + t.Fatalf("expected error when closing transaction") + } + + if !strings.Contains(err.Error(), "removing temporary file") { + t.Fatalf("unexpected error message: %s", err.Error()) } }