From bfd416b1542d2a7bcebad6d1ebca0dd1a075ff16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bendik=20Nesb=C3=B8?= Date: Wed, 21 Feb 2024 12:50:19 +0100 Subject: [PATCH] fix: Prevent accidental mutation of fields. --- entry.go | 18 ++++++---- entry_test.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++++++ logger_test.go | 38 ++++++++++++++++++++ 3 files changed, 145 insertions(+), 6 deletions(-) create mode 100644 entry_test.go diff --git a/entry.go b/entry.go index e595633..d648b00 100644 --- a/entry.go +++ b/entry.go @@ -20,22 +20,28 @@ func (e *Entry) WithError(err error) *Entry { // WithField forwards a logging call with a field func (e *Entry) WithField(key string, value interface{}) *Entry { - e.fields[key] = value - return e + return e.WithFields(Fields{key: value}) } // WithFields forwards a logging call with fields func (e *Entry) WithFields(fields Fields) *Entry { + // Make a copy, to prevent mutation of the old entry + newFields := make(Fields, len(e.fields)+len(fields)) + // Copy old fields + for k, v := range e.fields { + newFields[k] = v + } + // Set new fields for k, v := range fields { - e.fields[k] = v + newFields[k] = v } - return e + return &Entry{logger: e.logger, fields: newFields, context: e.context} } // WithContext sets the context for the log-message. Useful when using hooks. func (e *Entry) WithContext(ctx context.Context) *Entry { - e.context = ctx - return e + // Make a copy, to prevent mutation of the old entry + return &Entry{logger: e.logger, fields: e.fields, context: ctx} } // Info forwards a logging call in the (format, args) format diff --git a/entry_test.go b/entry_test.go new file mode 100644 index 0000000..563b263 --- /dev/null +++ b/entry_test.go @@ -0,0 +1,95 @@ +package logger + +import ( + "context" + "fmt" + "testing" +) + +func TestNoFieldsMutation(t *testing.T) { + testCases := map[string]struct { + apply func(e *Entry) *Entry + expectedEntries int + }{ + "WithError": { + apply: func(e *Entry) *Entry { + return e.WithError(fmt.Errorf("some error")) + }, + expectedEntries: 1, + }, + "WithField": { + apply: func(e *Entry) *Entry { + return e.WithField("single", "value") + }, + expectedEntries: 1, + }, + "WithFields": { + apply: func(e *Entry) *Entry { + return e.WithFields(Fields{"multiple1": "value", "multiple2": "value"}) + }, + expectedEntries: 2, + }, + "WithField+WithFields": { + apply: func(e *Entry) *Entry { + return e.WithField("single", "value").WithFields(Fields{"multiple1": "value", "multiple2": "value"}) + }, + expectedEntries: 3, + }, + "WithField+WithError": { + apply: func(e *Entry) *Entry { + return e.WithField("single", "value").WithError(fmt.Errorf("some error")) + }, + expectedEntries: 2, + }, + "Override WithField": { + apply: func(e *Entry) *Entry { + return e.WithField("single", "value").WithField("single", "overridden-value") + }, + expectedEntries: 1, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + entry := &Entry{} + if len(entry.fields) != 0 { + t.Fatalf("A blank entry should not contain any fields. Found %d", len(entry.fields)) + } + + newEntry := tc.apply(entry) + if len(entry.fields) != 0 { + t.Fatalf("Applying a func mutated the original entry. Found %d entries in original entry.", len(entry.fields)) + } + if len(newEntry.fields) != tc.expectedEntries { + t.Fatalf("Applying a func did not add fields to the new entry. Expected %d, found %d", tc.expectedEntries, len(newEntry.fields)) + } + }) + } +} + +func TestNoContextMutation(t *testing.T) { + entry1 := &Entry{} + if entry1.context != nil { + t.Fatalf("A blank entry should not contain a context.") + } + ctx1 := context.Background() + ctx2 := context.TODO() + + entry2 := entry1.WithContext(ctx1) + if entry1.context != nil { + t.Fatalf("WithContext mutated the original entry1.") + } + if entry2.context != ctx1 { + t.Fatalf("WithContext did not set context in entry2.") + } + + entry3 := entry2.WithContext(ctx2) + if entry1.context != nil { + t.Fatalf("WithContext mutated the original entry1.") + } + if entry2.context != ctx1 { + t.Fatalf("The second WithContext mutated entry2.") + } + if entry3.context != ctx2 { + t.Fatalf("The second WithContext did not set context in entry3.") + } +} diff --git a/logger_test.go b/logger_test.go index f51110d..d91be2a 100644 --- a/logger_test.go +++ b/logger_test.go @@ -373,3 +373,41 @@ func TestReuseEntry(t *testing.T) { assertLogEntryContains(t, strings.NewReader(lines[1]), "msg", "baz") assertLogEntryContains(t, strings.NewReader(lines[1]), "level", "error") } + +func TestReuseEntryWithFields(t *testing.T) { + // This test asserts that entry.WithField(...) does mutate the entry itself + builder := &strings.Builder{} + logger := New(WithOutput(builder), WithLevel(LevelInfo)) + entryWithFields := logger.WithField("foo", "bar") + entryWithFields.WithField("only-quoo", true).Info("quoo") + entryWithFields.WithError(fmt.Errorf("only-baz")).Error("baz") + entryWithFields.Info("final") + + str := builder.String() + lines := strings.Split(str, "\n") + if len(lines) != 4 { + // Info + Error + empty newline + t.Fatalf("expected %d lines, got %d", 4, len(lines)) + } + if lines[3] != "" { + t.Fatalf("expected last line to be empty, got %s", lines[3]) + } + + assertLogEntryContains(t, strings.NewReader(lines[0]), "foo", "bar") + assertLogEntryContains(t, strings.NewReader(lines[0]), "msg", "quoo") + assertLogEntryContains(t, strings.NewReader(lines[0]), "level", "info") + assertLogEntryContains(t, strings.NewReader(lines[0]), "only-quoo", true) + assertLogEntryDoesNotHaveKey(t, strings.NewReader(lines[0]), "error") + + assertLogEntryContains(t, strings.NewReader(lines[1]), "foo", "bar") + assertLogEntryContains(t, strings.NewReader(lines[1]), "msg", "baz") + assertLogEntryContains(t, strings.NewReader(lines[1]), "level", "error") + assertLogEntryDoesNotHaveKey(t, strings.NewReader(lines[1]), "only-quoo") + assertLogEntryContains(t, strings.NewReader(lines[1]), "error", "only-baz") + + assertLogEntryContains(t, strings.NewReader(lines[2]), "foo", "bar") + assertLogEntryContains(t, strings.NewReader(lines[2]), "msg", "final") + assertLogEntryContains(t, strings.NewReader(lines[2]), "level", "info") + assertLogEntryDoesNotHaveKey(t, strings.NewReader(lines[2]), "only-quoo") + assertLogEntryDoesNotHaveKey(t, strings.NewReader(lines[2]), "error") +}