Skip to content

Commit

Permalink
fix:force require 'A' in Auditlogparts (#801)
Browse files Browse the repository at this point in the history
* fix:require 'A' in Auditlogparts,force to break AuditLog when reached 'Z'

* feat: drops the A case in Transation.AuditLog as it implies A and Z will be present and in the right position due to the types parsing.

* breaking: drops A and Z types for auditlog parts as per @anuraaga suggestion.

Currently they don't need a type as parser can enforce their presence and they are always expected to be present in the same position.

---------

Co-authored-by: José Carlos Chávez <[email protected]>
  • Loading branch information
Hayak3 and jcchavezs authored May 31, 2023
1 parent 721d1de commit ad50864
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 76 deletions.
4 changes: 2 additions & 2 deletions internal/actions/ctl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ func TestCtl(t *testing.T) {
},
},
"auditLogParts": {
input: "auditLogParts=A",
input: "auditLogParts=ABZ",
checkTX: func(t *testing.T, tx *corazawaf.Transaction, logEntry string) {
if want, have := types.AuditLogPartAuditLogHeader, tx.AuditLogParts[0]; want != have {
if want, have := types.AuditLogPartRequestHeaders, tx.AuditLogParts[0]; want != have {
t.Errorf("Failed to set audit log parts, want %s, have %s", string(want), string(have))
}
},
Expand Down
7 changes: 3 additions & 4 deletions internal/auditlog/formats.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ func nativeFormatter(al plugintypes.AuditLog) ([]byte, error) {
res.WriteString(boundaryPrefix)
res.WriteByte(byte(part))
res.WriteString("--\n")
// [27/Jul/2016:05:46:16 +0200] V5guiH8AAQEAADTeJ2wAAAAK 192.168.3.1 50084 192.168.3.111 80
_, _ = fmt.Fprintf(&res, "[%s] %s %s %d %s %d", al.Transaction().Timestamp(), al.Transaction().ID(),
al.Transaction().ClientIP(), al.Transaction().ClientPort(), al.Transaction().HostIP(), al.Transaction().HostPort())
switch part {
case types.AuditLogPartAuditLogHeader:
// [27/Jul/2016:05:46:16 +0200] V5guiH8AAQEAADTeJ2wAAAAK 192.168.3.1 50084 192.168.3.111 80
_, _ = fmt.Fprintf(&res, "[%s] %s %s %d %s %d", al.Transaction().Timestamp(), al.Transaction().ID(),
al.Transaction().ClientIP(), al.Transaction().ClientPort(), al.Transaction().HostIP(), al.Transaction().HostPort())
case types.AuditLogPartRequestHeaders:
// GET /url HTTP/1.1
// Host: example.com
Expand Down
1 change: 0 additions & 1 deletion internal/auditlog/formats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ func TestNativeFormatter(t *testing.T) {
func createAuditLog() *Log {
return &Log{
Parts_: []types.AuditLogPart{
types.AuditLogPartAuditLogHeader,
types.AuditLogPartRequestHeaders,
types.AuditLogPartRequestBody,
types.AuditLogPartIntermediaryResponseBody,
Expand Down
68 changes: 34 additions & 34 deletions internal/corazawaf/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -1301,41 +1301,42 @@ func (tx *Transaction) LastPhase() types.RulePhase {
return tx.lastPhase
}

// AuditLog returns an AuditLog struct, used to write audit logs
// AuditLog returns an AuditLog struct, used to write audit logs.
// It implies the log parts starts with A and ends with Z as in the
// types.ParseAuditLogParts.
func (tx *Transaction) AuditLog() *auditlog.Log {
al := &auditlog.Log{}
al.Parts_ = tx.AuditLogParts

var alTransaction auditlog.Transaction
clientPort, _ := strconv.Atoi(tx.variables.remotePort.Get())
hostPort, _ := strconv.Atoi(tx.variables.serverPort.Get())
// YYYY/MM/DD HH:mm:ss
ts := time.Unix(0, tx.Timestamp).Format("2006/01/02 15:04:05")
al.Transaction_ = auditlog.Transaction{
Timestamp_: ts,
UnixTimestamp_: tx.Timestamp,
ID_: tx.id,
ClientIP_: tx.variables.remoteAddr.Get(),
ClientPort_: clientPort,
HostIP_: tx.variables.serverAddr.Get(),
HostPort_: hostPort,
ServerID_: tx.variables.serverName.Get(), // TODO check
}

for _, part := range tx.AuditLogParts {
switch part {
case types.AuditLogPartAuditLogHeader:
clientPort, _ := strconv.Atoi(tx.variables.remotePort.Get())
hostPort, _ := strconv.Atoi(tx.variables.serverPort.Get())
// YYYY/MM/DD HH:mm:ss
ts := time.Unix(0, tx.Timestamp).Format("2006/01/02 15:04:05")
alTransaction = auditlog.Transaction{
Timestamp_: ts,
UnixTimestamp_: tx.Timestamp,
ID_: tx.id,
ClientIP_: tx.variables.remoteAddr.Get(),
ClientPort_: clientPort,
HostIP_: tx.variables.serverAddr.Get(),
HostPort_: hostPort,
ServerID_: tx.variables.serverName.Get(), // TODO check
}
case types.AuditLogPartRequestHeaders:
if alTransaction.Request_ == nil {
alTransaction.Request_ = &auditlog.TransactionRequest{}
if al.Transaction_.Request_ == nil {
al.Transaction_.Request_ = &auditlog.TransactionRequest{}
}
alTransaction.Request_.Headers_ = tx.variables.requestHeaders.Data()
al.Transaction_.Request_.Headers_ = tx.variables.requestHeaders.Data()
case types.AuditLogPartRequestBody:
if alTransaction.Request_ == nil {
alTransaction.Request_ = &auditlog.TransactionRequest{}
if al.Transaction_.Request_ == nil {
al.Transaction_.Request_ = &auditlog.TransactionRequest{}
}
// TODO maybe change to:
// al.Transaction.Request.Body = tx.RequestBodyBuffer.String()
alTransaction.Request_.Body_ = tx.variables.requestBody.Get()
al.Transaction_.Request_.Body_ = tx.variables.requestBody.Get()

/*
* TODO:
Expand All @@ -1347,7 +1348,7 @@ func (tx *Transaction) AuditLog() *auditlog.Log {
*/
// upload data
var files []plugintypes.AuditLogTransactionRequestFiles
alTransaction.Request_.Files_ = nil
al.Transaction_.Request_.Files_ = nil
for _, file := range tx.variables.files.Get("") {
var size int64
if fs := tx.variables.filesSizes.Get(file); len(fs) > 0 {
Expand All @@ -1362,21 +1363,21 @@ func (tx *Transaction) AuditLog() *auditlog.Log {
}
files = append(files, at)
}
alTransaction.Request_.Files_ = files
al.Transaction_.Request_.Files_ = files
case types.AuditLogPartIntermediaryResponseBody:
if alTransaction.Response_ == nil {
alTransaction.Response_ = &auditlog.TransactionResponse{}
if al.Transaction_.Response_ == nil {
al.Transaction_.Response_ = &auditlog.TransactionResponse{}
}
alTransaction.Response_.Body_ = tx.variables.responseBody.Get()
al.Transaction_.Response_.Body_ = tx.variables.responseBody.Get()
case types.AuditLogPartResponseHeaders:
if alTransaction.Response_ == nil {
alTransaction.Response_ = &auditlog.TransactionResponse{}
if al.Transaction_.Response_ == nil {
al.Transaction_.Response_ = &auditlog.TransactionResponse{}
}
status, _ := strconv.Atoi(tx.variables.responseStatus.Get())
alTransaction.Response_.Status_ = status
alTransaction.Response_.Headers_ = tx.variables.responseHeaders.Data()
al.Transaction_.Response_.Status_ = status
al.Transaction_.Response_.Headers_ = tx.variables.responseHeaders.Data()
case types.AuditLogPartAuditLogTrailer:
alTransaction.Producer_ = &auditlog.TransactionProducer{
al.Transaction_.Producer_ = &auditlog.TransactionProducer{
Connector_: tx.WAF.ProducerConnector,
Version_: tx.WAF.ProducerConnectorVersion,
Server_: "",
Expand Down Expand Up @@ -1411,7 +1412,6 @@ func (tx *Transaction) AuditLog() *auditlog.Log {
}
}

al.Transaction_ = alTransaction
return al
}

Expand Down
24 changes: 0 additions & 24 deletions internal/corazawaf/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,29 +430,6 @@ func TestAuditLog(t *testing.T) {
}
}

func TestParseAuditLog(t *testing.T) {
AuditLogParts, err := types.ParseAuditLogParts("ABCDEFGHIJK")
if err != nil {
t.Error("unexpected audit log parts")
}
expected := types.AuditLogParts("ABCDEFGHIJK")
if len(AuditLogParts) != len(expected) {
t.Error("AuditLogParts has different length than expected")

}
for i := 0; i < len(AuditLogParts); i++ {
if AuditLogParts[i] != expected[i] {
t.Errorf("Byte at position %d differs", i)
}
}
}
func TestInvalidAuditLog(t *testing.T) {
AuditLogParts, err := types.ParseAuditLogParts("ABCDEFGHIJKLMN")
if err == nil || len(AuditLogParts) != 0 {
t.Error("AuditLogParts should fail of invalid part")
}
}

var responseBodyWriters = map[string]func(tx *Transaction, body string) (*types.Interruption, int, error){
"WriteResponsequestBody": func(tx *Transaction, body string) (*types.Interruption, int, error) {
return tx.WriteResponseBody([]byte(body))
Expand Down Expand Up @@ -1106,7 +1083,6 @@ func TestTxSetServerName(t *testing.T) {
if want, have := "SetServerName has been called after ProcessRequestHeaders", logEntries[0]; !strings.Contains(have, want) {
t.Fatalf("unexpected message, want %q, have %q", want, have)
}

}

func TestTxAddArgument(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion internal/seclang/rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func TestSecAuditLogs(t *testing.T) {
err := parser.FromString(`
SecAuditEngine On
SecAction "id:4482,log,auditlog, msg:'test'"
SecAuditLogParts ABCDEFGHIJK
SecAuditLogParts ABCDEFGHIJKZ
SecRuleEngine On
`)
if err != nil {
Expand Down
25 changes: 15 additions & 10 deletions types/waf.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package types

import (
"errors"
"fmt"
"strings"
)
Expand Down Expand Up @@ -107,7 +108,6 @@ type AuditLogPart byte
type AuditLogParts []AuditLogPart

var validOpts = map[AuditLogPart]struct{}{
AuditLogPartAuditLogHeader: {},
AuditLogPartRequestHeaders: {},
AuditLogPartRequestBody: {},
AuditLogPartIntermediaryResponseHeaders: {},
Expand All @@ -118,21 +118,28 @@ var validOpts = map[AuditLogPart]struct{}{
AuditLogPartRequestBodyAlternative: {},
AuditLogPartUploadedFiles: {},
AuditLogPartRulesMatched: {},
AuditLogPartFinalBoundary: {},
}

// ParseAuditLogParts parses the audit log parts
func ParseAuditLogParts(opts string) (AuditLogParts, error) {
for _, opt := range opts {
if _, ok := validOpts[AuditLogPart(opt)]; !ok {
return AuditLogParts(""), fmt.Errorf("invalid audit log part: %s", opts)
if !strings.HasPrefix(opts, "A") {
return nil, errors.New("audit log parts is required to start with A")
}

if !strings.HasSuffix(opts, "Z") {
return nil, errors.New("audit log parts is required to end with Z")
}

parts := opts[1 : len(opts)-1]
for _, p := range parts {
if _, ok := validOpts[AuditLogPart(p)]; !ok {
return AuditLogParts(""), fmt.Errorf("invalid audit log parts %q", opts)
}
}
return AuditLogParts(opts), nil
return AuditLogParts(parts), nil
}

const (
// AuditLogPartAuditLogHeader is the mandatory header part
AuditLogPartAuditLogHeader AuditLogPart = 'A'
// AuditLogPartRequestHeaders is the request headers part
AuditLogPartRequestHeaders AuditLogPart = 'B'
// AuditLogPartRequestBody is the request body part
Expand All @@ -153,8 +160,6 @@ const (
AuditLogPartUploadedFiles AuditLogPart = 'J'
// AuditLogPartRulesMatched is the matched rules part
AuditLogPartRulesMatched AuditLogPart = 'K'
// AuditLogPartFinalBoundary is the mandatory final boundary part
AuditLogPartFinalBoundary AuditLogPart = 'Z'
)

// Interruption is used to notify the Coraza implementation
Expand Down
42 changes: 42 additions & 0 deletions types/waf_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package types

import "testing"

func TestParseAuditLogParts(t *testing.T) {
tests := []struct {
input string
expectedParts AuditLogParts
expectedHasError bool
}{
{"", nil, true},
{"ABCDEFGHIJKZ", []AuditLogPart("BCDEFGHIJK"), false},
{"DEFGHZ", nil, true},
{"ABCD", nil, true},
{"AMZ", nil, true},
}

for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
parts, err := ParseAuditLogParts(test.input)
if test.expectedHasError {
if err == nil {
t.Error("expected error")
}
} else {
if err != nil {
t.Error("unexpected error")
}

if want, have := len(test.expectedParts), len(parts); want != have {
t.Errorf("unexpected parts length, want %d, have %d", want, have)
}

for i, part := range test.expectedParts {
if want, have := part, parts[i]; want != have {
t.Errorf("unexpected part, want %q, have %q", want, have)
}
}
}
})
}
}

0 comments on commit ad50864

Please sign in to comment.