diff --git a/message.go b/message.go index 51c7a47..980e24d 100644 --- a/message.go +++ b/message.go @@ -8,6 +8,7 @@ import ( "regexp" "sort" "strconv" + "sync" "github.com/moov-io/iso8583/field" "github.com/moov-io/iso8583/utils" @@ -28,6 +29,9 @@ type Message struct { // stores all fields according to the spec fields map[int]field.Field + // to guard fieldsMap + mu sync.Mutex + // tracks which fields were set fieldsMap map[int]struct{} } @@ -62,13 +66,19 @@ func (m *Message) Bitmap() *field.Bitmap { // it exists and is of the correct type m.bitmap, _ = m.fields[bitmapIdx].(*field.Bitmap) m.bitmap.Reset() + + m.mu.Lock() m.fieldsMap[bitmapIdx] = struct{}{} + m.mu.Unlock() return m.bitmap } func (m *Message) MTI(val string) { + m.mu.Lock() m.fieldsMap[mtiIdx] = struct{}{} + m.mu.Unlock() + m.fields[mtiIdx].SetBytes([]byte(val)) } @@ -78,7 +88,9 @@ func (m *Message) GetSpec() *MessageSpec { func (m *Message) Field(id int, val string) error { if f, ok := m.fields[id]; ok { + m.mu.Lock() m.fieldsMap[id] = struct{}{} + m.mu.Unlock() return f.SetBytes([]byte(val)) } return fmt.Errorf("failed to set field %d. ID does not exist", id) @@ -86,7 +98,10 @@ func (m *Message) Field(id int, val string) error { func (m *Message) BinaryField(id int, val []byte) error { if f, ok := m.fields[id]; ok { + m.mu.Lock() m.fieldsMap[id] = struct{}{} + m.mu.Unlock() + return f.SetBytes(val) } return fmt.Errorf("failed to set binary field %d. ID does not exist", id) @@ -99,7 +114,10 @@ func (m *Message) GetMTI() (string, error) { func (m *Message) GetString(id int) (string, error) { if f, ok := m.fields[id]; ok { + m.mu.Lock() m.fieldsMap[id] = struct{}{} + m.mu.Unlock() + return f.String() } return "", fmt.Errorf("failed to get string for field %d. ID does not exist", id) @@ -107,7 +125,10 @@ func (m *Message) GetString(id int) (string, error) { func (m *Message) GetBytes(id int) ([]byte, error) { if f, ok := m.fields[id]; ok { + m.mu.Lock() m.fieldsMap[id] = struct{}{} + m.mu.Unlock() + return f.Bytes() } return nil, fmt.Errorf("failed to get bytes for field %d. ID does not exist", id) @@ -119,6 +140,9 @@ func (m *Message) GetField(id int) field.Field { // Fields returns the map of the set fields func (m *Message) GetFields() map[int]field.Field { + m.mu.Lock() + defer m.mu.Unlock() + fields := map[int]field.Field{} for i := range m.fieldsMap { fields[i] = m.GetField(i) @@ -192,11 +216,19 @@ func (m *Message) Unpack(src []byte) error { func (m *Message) unpack(src []byte) error { var off int + m.mu.Lock() // reset fields that were set m.fieldsMap = map[int]struct{}{} + // we unlock here as m.Bitmap() will lock the mutex again + m.mu.Unlock() + bitmap := m.Bitmap() // This method implicitly also sets m.fieldsMap[bitmapIdx] - m.Bitmap().Reset() + bitmap.Reset() + + // lock the mutex again as we're going to set fields + m.mu.Lock() + defer m.mu.Unlock() read, err := m.fields[mtiIdx].Unpack(src) if err != nil { @@ -215,13 +247,13 @@ func (m *Message) unpack(src []byte) error { off += read - for i := 2; i <= m.Bitmap().Len(); i++ { + for i := 2; i <= bitmap.Len(); i++ { // skip bitmap presence bits (for default bitmap length of 64 these are bits 1, 65, 129, 193, etc.) - if m.Bitmap().IsBitmapPresenceBit(i) { + if bitmap.IsBitmapPresenceBit(i) { continue } - if m.Bitmap().IsSet(i) { + if bitmap.IsSet(i) { fl, ok := m.fields[i] if !ok { return fmt.Errorf("failed to unpack field %d: no specification found", i) @@ -264,6 +296,9 @@ func (m *Message) MarshalJSON() ([]byte, error) { } func (m *Message) UnmarshalJSON(b []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + var data map[string]json.RawMessage if err := json.Unmarshal(b, &data); err != nil { return err @@ -291,6 +326,9 @@ func (m *Message) UnmarshalJSON(b []byte) error { } func (m *Message) packableFieldIDs() ([]int, error) { + m.mu.Lock() + defer m.mu.Unlock() + // Index 1 represent bitmap which is always populated. populatedFieldIDs := []int{1} @@ -338,6 +376,9 @@ func (m *Message) Clone() (*Message, error) { // through the message fields and calls Unmarshal(...) on them setting the v If // v is not a struct or not a pointer to struct then it returns error. func (m *Message) Marshal(v interface{}) error { + m.mu.Lock() + defer m.mu.Unlock() + if v == nil { return nil } @@ -391,6 +432,9 @@ func (m *Message) Marshal(v interface{}) error { // through the message fields and calls Unmarshal(...) on them setting the v If // v is nil or not a pointer it returns error. func (m *Message) Unmarshal(v interface{}) error { + m.mu.Lock() + defer m.mu.Unlock() + rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { return errors.New("data is not a pointer or nil") diff --git a/message_test.go b/message_test.go index 2989d4c..2eb4868 100644 --- a/message_test.go +++ b/message_test.go @@ -3,7 +3,10 @@ package iso8583 import ( "encoding/hex" "encoding/json" + "log" + "net/http" "reflect" + "sync" "testing" "time" @@ -14,9 +17,15 @@ import ( "github.com/moov-io/iso8583/sort" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + _ "net/http/pprof" ) func TestMessage(t *testing.T) { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + spec := &MessageSpec{ Fields: map[int]field.Field{ 0: field.NewString(&field.Spec{ @@ -90,6 +99,27 @@ func TestMessage(t *testing.T) { }, } + // this test most probably will fail in regular mode, + // and should fail when is run with -race flag + t.Run("No data race when accessing fields concurrently", func(t *testing.T) { + message := NewMessage(spec) + + var wg sync.WaitGroup + + for i := 0; i < 1000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + // calling GetString writes into the map of the + // set fields + message.GetString(0) + }() + } + + wg.Wait() + }) + t.Run("Test packing and unpacking untyped fields", func(t *testing.T) { message := NewMessage(spec) message.MTI("0100")