Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the package safe for concurrent access #284

Merged
merged 7 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 70 additions & 17 deletions field/composite.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"reflect"
"regexp"
"strconv"
"sync"

"github.com/moov-io/iso8583/encoding"
"github.com/moov-io/iso8583/prefix"
Expand Down Expand Up @@ -63,11 +64,15 @@ var _ json.Unmarshaler = (*Composite)(nil)
// For the sake of determinism, packing of subfields is executed in order of Tag
// (using Spec.Tag.Sort) regardless of the value of Spec.Tag.Length.
type Composite struct {
spec *Spec
bitmap *Bitmap
spec *Spec
cachedBitmap *Bitmap

orderedSpecFieldTags []string

// mu is used to synchronize access to the subfields and
// setSubfields maps when the composite is used concurrently
mu sync.Mutex

// stores all fields according to the spec
subfields map[string]Field

Expand All @@ -92,7 +97,13 @@ type CompositeWithSubfields interface {
ConstructSubfields()
}

// ConstructSubfields creates subfields according to the spec
// this method is used when composite field is created without
// calling NewComposite (when we create message spec and composite spec)
func (f *Composite) ConstructSubfields() {
f.mu.Lock()
defer f.mu.Unlock()

if f.subfields == nil {
f.subfields = CreateSubfields(f.spec)
}
Expand All @@ -106,6 +117,15 @@ func (f *Composite) Spec() *Spec {

// GetSubfields returns the map of set sub fields
func (f *Composite) GetSubfields() map[string]Field {
f.mu.Lock()
defer f.mu.Unlock()

return f.getSubfields()
}

// getSubfields returns the map of set sub fields, it should be called
// only when the mutex is locked
func (f *Composite) getSubfields() map[string]Field {
fields := map[string]Field{}
for i := range f.setSubfields {
fields[i] = f.subfields[i]
Expand All @@ -120,7 +140,7 @@ func (f *Composite) GetSubfields() map[string]Field {
// will result in a panic.
func (f *Composite) SetSpec(spec *Spec) {
if err := spec.Validate(); err != nil {
panic(err) //nolint // as specs moslty static, we panic on spec validation errors
panic(err) //nolint:forbidigo,nolintlint // as specs moslty static, we panic on spec validation errors
}
f.spec = spec

Expand All @@ -137,6 +157,9 @@ func (f *Composite) SetSpec(spec *Spec) {
}

func (f *Composite) Unmarshal(v interface{}) error {
f.mu.Lock()
defer f.mu.Unlock()

rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return errors.New("data is not a pointer or nil")
Expand Down Expand Up @@ -203,6 +226,9 @@ func (f *Composite) SetData(v interface{}) error {
// F4 *SubfieldCompositeData
// }
func (f *Composite) Marshal(v interface{}) error {
f.mu.Lock()
defer f.mu.Unlock()

rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return errors.New("data is not a pointer or nil")
Expand Down Expand Up @@ -251,6 +277,9 @@ func (f *Composite) Marshal(v interface{}) error {
// Pack deserialises data held by the receiver (via SetData)
// into bytes and returns an error on failure.
func (f *Composite) Pack() ([]byte, error) {
f.mu.Lock()
defer f.mu.Unlock()

packed, err := f.pack()
if err != nil {
return nil, err
Expand All @@ -268,6 +297,9 @@ func (f *Composite) Pack() ([]byte, error) {
// subfields. An offset (unit depends on encoding and prefix values) is
// returned on success. A non-nil error is returned on failure.
func (f *Composite) Unpack(data []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()

dataLen, offset, err := f.spec.Pref.DecodeLength(f.spec.Length, data)
if err != nil {
return 0, fmt.Errorf("failed to decode length: %w", err)
Expand Down Expand Up @@ -300,6 +332,9 @@ func (f *Composite) Unpack(data []byte) (int, error) {
// pack all subfields in full. However, unlike Unpack(), it requires the
// aggregate length of the subfields not to be encoded in the prefix.
func (f *Composite) SetBytes(data []byte) error {
f.mu.Lock()
defer f.mu.Unlock()

_, err := f.unpack(data, false)
return err
}
Expand All @@ -308,14 +343,26 @@ func (f *Composite) SetBytes(data []byte) error {
// does not incorporate the encoded aggregate length of the subfields in the
// prefix.
func (f *Composite) Bytes() ([]byte, error) {
f.mu.Lock()
defer f.mu.Unlock()

return f.pack()
}

// Bitmap returns the parsed bitmap instantiated on the key "0" of the spec.
// In case the bitmap is not instantiated on the spec, returns nil.
func (f *Composite) Bitmap() *Bitmap {
if f.bitmap != nil {
return f.bitmap
f.mu.Lock()
defer f.mu.Unlock()

return f.bitmap()
}

func (f *Composite) bitmap() *Bitmap {
// TODO: protect against concurrent access
alovak marked this conversation as resolved.
Show resolved Hide resolved

if f.cachedBitmap != nil {
return f.cachedBitmap
}

if f.spec.Bitmap == nil {
Expand All @@ -327,9 +374,9 @@ func (f *Composite) Bitmap() *Bitmap {
return nil
}

f.bitmap = bitmap
f.cachedBitmap = bitmap

return f.bitmap
return f.cachedBitmap
}

// String iterates over the receiver's subfields, packs them and converts the
Expand All @@ -345,7 +392,10 @@ func (f *Composite) String() (string, error) {

// MarshalJSON implements the encoding/json.Marshaler interface.
func (f *Composite) MarshalJSON() ([]byte, error) {
jsonData := OrderedMap(f.GetSubfields())
f.mu.Lock()
defer f.mu.Unlock()

jsonData := OrderedMap(f.getSubfields())
bytes, err := json.Marshal(jsonData)
if err != nil {
return nil, utils.NewSafeError(err, "failed to JSON marshal map to bytes")
Expand All @@ -357,6 +407,9 @@ func (f *Composite) MarshalJSON() ([]byte, error) {
// An error is thrown if the JSON consists of a subfield that has not
// been defined in the spec.
func (f *Composite) UnmarshalJSON(b []byte) error {
f.mu.Lock()
defer f.mu.Unlock()

var data map[string]json.RawMessage
err := json.Unmarshal(b, &data)
if err != nil {
Expand Down Expand Up @@ -384,15 +437,15 @@ func (f *Composite) UnmarshalJSON(b []byte) error {
}

func (f *Composite) pack() ([]byte, error) {
if f.Bitmap() != nil {
if f.bitmap() != nil {
return f.packByBitmap()
}

return f.packByTag()
}

func (f *Composite) packByBitmap() ([]byte, error) {
f.Bitmap().Reset()
f.bitmap().Reset()

var packedFields []byte

Expand All @@ -409,7 +462,7 @@ func (f *Composite) packByBitmap() ([]byte, error) {
}

// set bitmap bit for this field
f.Bitmap().Set(idInt)
f.bitmap().Set(idInt)

field, ok := f.subfields[id]
if !ok {
Expand All @@ -425,7 +478,7 @@ func (f *Composite) packByBitmap() ([]byte, error) {
}

// pack bitmap.
packedBitmap, err := f.Bitmap().Pack()
packedBitmap, err := f.bitmap().Pack()
if err != nil {
return nil, fmt.Errorf("packing bitmap: %w", err)
}
Expand Down Expand Up @@ -469,7 +522,7 @@ func (f *Composite) packByTag() ([]byte, error) {
}

func (f *Composite) unpack(data []byte, isVariableLength bool) (int, error) {
if f.Bitmap() != nil {
if f.bitmap() != nil {
return f.unpackSubfieldsByBitmap(data)
}
if f.spec.Tag.Enc != nil {
Expand Down Expand Up @@ -509,17 +562,17 @@ func (f *Composite) unpackSubfieldsByBitmap(data []byte) (int, error) {
// Reset fields that were set.
f.setSubfields = make(map[string]struct{})

f.Bitmap().Reset()
f.bitmap().Reset()

read, err := f.Bitmap().Unpack(data[off:])
read, err := f.bitmap().Unpack(data[off:])
if err != nil {
return 0, fmt.Errorf("failed to unpack bitmap: %w", err)
}

off += read

for i := 1; i <= f.Bitmap().Len(); i++ {
if f.Bitmap().IsSet(i) {
for i := 1; i <= f.bitmap().Len(); i++ {
if f.bitmap().IsSet(i) {
iStr := strconv.Itoa(i)

fl, ok := f.subfields[iStr]
Expand Down
59 changes: 59 additions & 0 deletions field/composite_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package field

import (
"encoding/hex"
"fmt"
"reflect"
"strconv"
"sync"
"testing"

"github.com/moov-io/iso8583/encoding"
Expand Down Expand Up @@ -1892,3 +1894,60 @@ func TestComposite_getFieldIndexOrTag(t *testing.T) {
require.Empty(t, index)
})
}

func TestComposit_concurrency(t *testing.T) {
t.Run("Pack and Marshal", func(t *testing.T) {
// packing and marshaling
data := &TLVTestData{
F9A: NewHexValue("210720"),
F9F02: NewHexValue("000000000501"),
}

composite := NewComposite(tlvTestSpec)

wg := sync.WaitGroup{}
wg.Add(5)
alovak marked this conversation as resolved.
Show resolved Hide resolved

for i := 0; i < 5; i++ {
go func() {
defer wg.Done()

err := composite.Marshal(data)
require.NoError(t, err)

_, err = composite.Pack()
require.NoError(t, err)
}()
}

wg.Wait()
})

t.Run("Unpack and Unmarshal", func(t *testing.T) {
packed, err := hex.DecodeString("3031349A032107209F0206000000000501")
require.NoError(t, err)

composite := NewComposite(tlvTestSpec)

wg := sync.WaitGroup{}
wg.Add(5)

for i := 0; i < 5; i++ {
go func() {
defer wg.Done()

data := &TLVTestData{}
_, err := composite.Unpack(packed)
require.NoError(t, err)

err = composite.Unmarshal(data)
require.NoError(t, err)

require.Equal(t, "210720", data.F9A.Value())
require.Equal(t, "000000000501", data.F9F02.Value())
}()
}

wg.Wait()
})
}
Loading
Loading