From 7e3f9ee1f5fe090a4ec4371834be3672716696ac Mon Sep 17 00:00:00 2001 From: Benjamin Bengfort Date: Thu, 12 Dec 2024 09:37:36 -0500 Subject: [PATCH] Conflict Free Distributed Version Numbers (#39) --- pkg/store/lamport/lamport.go | 50 ++++++ pkg/store/lamport/lamport_test.go | 95 ++++++++++++ pkg/store/lamport/pid.go | 11 ++ pkg/store/lamport/pid_test.go | 15 ++ pkg/store/lamport/scalar.go | 172 +++++++++++++++++++++ pkg/store/lamport/scalar_test.go | 178 ++++++++++++++++++++++ pkg/store/metadata/testdata/metadata.json | 10 +- pkg/store/metadata/version.go | 35 ++--- pkg/store/object/object_test.go | 20 +-- pkg/store/object/testdata/metadata.json | 10 +- 10 files changed, 545 insertions(+), 51 deletions(-) create mode 100644 pkg/store/lamport/lamport.go create mode 100644 pkg/store/lamport/lamport_test.go create mode 100644 pkg/store/lamport/pid.go create mode 100644 pkg/store/lamport/pid_test.go create mode 100644 pkg/store/lamport/scalar.go create mode 100644 pkg/store/lamport/scalar_test.go diff --git a/pkg/store/lamport/lamport.go b/pkg/store/lamport/lamport.go new file mode 100644 index 0000000..6b4aa20 --- /dev/null +++ b/pkg/store/lamport/lamport.go @@ -0,0 +1,50 @@ +package lamport + +import "sync" + +// New returns a new clock with the specified PID. The returned clock is thread-safe +// and uses a mutex to guard against updates from multiple threads. +func New(pid uint32) Clock { + return &clock{pid: pid} +} + +// A Lamport Clock keeps track of the current process instance and monotonically +// increasing sequence timestamp. It is used to track the conflict-free distributed +// version scalar inside of a replicated system. +type Clock interface { + // Return the next timestamp using the internal process ID of the clock. + Next() Scalar + + // Update the clock with a timestamp scalar; if the scalar happens before the + // current timestamp in the clock it is ignored. + Update(Scalar) +} + +var _ Clock = &clock{} + +type clock struct { + sync.Mutex + pid uint32 + current Scalar +} + +func (c *clock) Next() Scalar { + c.Lock() + defer c.Unlock() + + c.current = Scalar{ + PID: c.pid, + VID: c.current.VID + 1, + } + + // Ensure a copy is returned so that the clock cannot be modified externally + return c.current +} + +func (c *clock) Update(now Scalar) { + c.Lock() + defer c.Unlock() + if now.After(&c.current) { + c.current = now + } +} diff --git a/pkg/store/lamport/lamport_test.go b/pkg/store/lamport/lamport_test.go new file mode 100644 index 0000000..8e86493 --- /dev/null +++ b/pkg/store/lamport/lamport_test.go @@ -0,0 +1,95 @@ +package lamport_test + +import ( + "math/rand" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/rotationalio/honu/pkg/store/lamport" + "github.com/stretchr/testify/require" +) + +func TestClock(t *testing.T) { + clock := lamport.New(1) + require.Equal(t, lamport.Scalar{1, 1}, clock.Next(), "expected next timestamp to be 1") + + clock.Update(lamport.Scalar{}) + require.Equal(t, lamport.Scalar{1, 2}, clock.Next(), "expected next timestamp to be 2") + + clock.Update(lamport.Scalar{2, 2}) + require.Equal(t, lamport.Scalar{1, 3}, clock.Next(), "expected next timestamp to be 3") + + clock.Update(lamport.Scalar{2, 6}) + require.Equal(t, lamport.Scalar{1, 7}, clock.Next(), "expected next timestamp to be 7") +} + +func TestClockConcurrency(t *testing.T) { + // Test concurrent clock operations by running a large number of threads with + // indpendent read and write clocks and ensure that the next version number is + // always after the previous version number when generated. + var ( + wg, rg sync.WaitGroup + failures int64 + ) + + // Create broadcast channel to simulate synchronization + C := make([]chan *lamport.Scalar, 0, 3) + + for i := 1; i < 4; i++ { + bc := make(chan *lamport.Scalar, 1) + + wg.Add(1) + go func(wg *sync.WaitGroup, c chan *lamport.Scalar) { + defer wg.Done() + defer close(c) + + var sg sync.WaitGroup + clock := lamport.New(uint32(i)) + + // Kick off the receiver routine + rg.Add(1) + go func(rg *sync.WaitGroup, c <-chan *lamport.Scalar) { + defer rg.Done() + for v := range c { + clock.Update(*v) + } + }(&rg, c) + + // Kick off several updater routines + for j := 0; j < 3; j++ { + sg.Add(1) + go func(sg *sync.WaitGroup) { + defer sg.Done() + prev := &lamport.Scalar{} + for k := 0; k < 16; k++ { + time.Sleep(time.Millisecond * time.Duration(rand.Int63n(32)+2)) + + now := clock.Next() + if !now.After(prev) { + atomic.AddInt64(&failures, 1) + return + } + + // Broadcast the update + prev = &now + for _, c := range C { + c <- prev + } + } + }(&sg) + } + + sg.Wait() + }(&wg, bc) + } + + wg.Wait() + for _, c := range C { + close(c) + } + rg.Wait() + + require.Zero(t, failures, "expected zero failures to have occurred") +} diff --git a/pkg/store/lamport/pid.go b/pkg/store/lamport/pid.go new file mode 100644 index 0000000..ba54402 --- /dev/null +++ b/pkg/store/lamport/pid.go @@ -0,0 +1,11 @@ +package lamport + +// The PID type makes it easy to create scalar versions from a previous version. +type PID uint32 + +func (p PID) Next(v *Scalar) *Scalar { + if v == nil { + return &Scalar{PID: uint32(p), VID: 1} + } + return &Scalar{PID: uint32(p), VID: v.VID + 1} +} diff --git a/pkg/store/lamport/pid_test.go b/pkg/store/lamport/pid_test.go new file mode 100644 index 0000000..0748349 --- /dev/null +++ b/pkg/store/lamport/pid_test.go @@ -0,0 +1,15 @@ +package lamport_test + +import ( + "testing" + + "github.com/rotationalio/honu/pkg/store/lamport" + "github.com/stretchr/testify/require" +) + +func TestPID(t *testing.T) { + pid := lamport.PID(42) + require.Equal(t, &lamport.Scalar{42, 1}, pid.Next(nil), "nil pid didn't return expected next") + require.Equal(t, &lamport.Scalar{42, 1}, pid.Next(&lamport.Scalar{}), "zero pid didn't return expected next") + require.Equal(t, &lamport.Scalar{42, 19}, pid.Next(&lamport.Scalar{42, 18}), "same pid didn't return expected next") +} diff --git a/pkg/store/lamport/scalar.go b/pkg/store/lamport/scalar.go new file mode 100644 index 0000000..8a9f845 --- /dev/null +++ b/pkg/store/lamport/scalar.go @@ -0,0 +1,172 @@ +package lamport + +import ( + "bytes" + "encoding" + "encoding/binary" + "errors" + "fmt" + "regexp" + "strconv" + + "github.com/rotationalio/honu/pkg/store/lani" +) + +// A lamport scalar indicates a timestamp or version in a vector clock that provides a +// "happens before" relationship between two scalars and is used to create distributed +// version numbers for eventually consistent systems and uses a "latest writer wins" +// policy to determine how to solve conflicts. +// +// Each scalar is 12 bytes and is composed of an 4 byte PID which should be unique to +// every process in the system, and an 8 byte monotonically increasing VID that +// represents the next latest version. In the case of a scalar with the same VID, the +// scalar with the larger PID happens before the scalar with the smaller PID (e.g. +// older processes, processes with smaller PIDs, win ties). +type Scalar struct { + PID uint32 + VID uint64 +} + +var ( + zero = &Scalar{0, 0} + scre = regexp.MustCompile(`^(\d+)\.(\d+)$`) + _ lani.Encodable = &Scalar{} + _ lani.Decodable = &Scalar{} + _ encoding.BinaryMarshaler = &Scalar{} + _ encoding.BinaryUnmarshaler = &Scalar{} + _ encoding.TextMarshaler = &Scalar{} + _ encoding.TextUnmarshaler = &Scalar{} +) + +const scalarSize = binary.MaxVarintLen32 + binary.MaxVarintLen64 + +// Compare returns an integer comparing two scalars using a happens before relationship. +// The result will be 0 if a == b, -1 if a < b (e.g. a happens before b), and +// +1 if a > b (e.g. b happens before a). A nil argument is equivalent to a zero scalar. +func Compare(a, b *Scalar) int { + if a == nil && b == nil { + return 0 + } + + if a == nil { + a = zero + } + + if b == nil { + b = zero + } + + if a.VID == b.VID { + switch { + case a.PID < b.PID: + return -1 + case a.PID > b.PID: + return 1 + default: + return 0 + } + } + + if a.VID > b.VID { + return 1 + } + return -1 +} + +// Returns true if the scalar is the zero-valued scalar (0.0) +func (s *Scalar) IsZero() bool { + return s.PID == 0 && s.VID == 0 +} + +// Returns true if the scalar is equal to the input scalar. +func (s *Scalar) Equals(o *Scalar) bool { + return Compare(s, o) == 0 +} + +// Returns true if the scalar is less than the input scalar (e.g. this scalar happens +// before the input scalar). +func (s *Scalar) Before(o *Scalar) bool { + return Compare(s, o) < 0 +} + +// Returns true if the scalar is grater than the input scalar (e.g. the input scalar +// hapens before this scalar). +func (s *Scalar) After(o *Scalar) bool { + return Compare(s, o) > 0 +} + +func (s *Scalar) Size() int { + return scalarSize +} + +func (s *Scalar) Encode(e *lani.Encoder) (n int, err error) { + var m int + if m, err = e.EncodeUint32(s.PID); err != nil { + return n, err + } + n += m + + if m, err = e.EncodeUint64(s.VID); err != nil { + return n, err + } + n += m + + return n, nil +} + +func (s *Scalar) Decode(d *lani.Decoder) (err error) { + if s.PID, err = d.DecodeUint32(); err != nil { + return err + } + + if s.VID, err = d.DecodeUint64(); err != nil { + return err + } + + return nil +} + +func (s *Scalar) MarshalBinary() ([]byte, error) { + e := &lani.Encoder{} + e.Grow(s.Size()) + + if _, err := s.Encode(e); err != nil { + return nil, err + } + + return e.Bytes(), nil +} + +func (s *Scalar) UnmarshalBinary(data []byte) (err error) { + d := lani.NewDecoder(data) + return s.Decode(d) +} + +func (s *Scalar) MarshalText() (_ []byte, err error) { + return []byte(s.String()), nil +} + +func (s *Scalar) UnmarshalText(text []byte) (err error) { + if !scre.Match(text) { + return errors.New("could not parse text representation of a scalar") + } + + parts := bytes.Split(text, []byte{'.'}) + + var pid uint64 + if pid, err = strconv.ParseUint(string(parts[0]), 10, 32); err != nil { + panic("pid is not parseable even though regular expression matched.") + } + s.PID = uint32(pid) + + if s.VID, err = strconv.ParseUint(string(parts[1]), 10, 32); err != nil { + panic("pid is not parseable even though regular expression matched.") + } + + return nil +} + +// Returns a scalar version representation in the form PID.VID using decimal notation. +func (s *Scalar) String() string { + return fmt.Sprintf("%d.%d", s.PID, s.VID) +} diff --git a/pkg/store/lamport/scalar_test.go b/pkg/store/lamport/scalar_test.go new file mode 100644 index 0000000..89cc2a4 --- /dev/null +++ b/pkg/store/lamport/scalar_test.go @@ -0,0 +1,178 @@ +package lamport_test + +import ( + "encoding/json" + "math/rand/v2" + "testing" + + . "github.com/rotationalio/honu/pkg/store/lamport" + "github.com/stretchr/testify/require" +) + +func TestScalar(t *testing.T) { + zero := &Scalar{} + one := &Scalar{1, 1} + + t.Run("IsZero", func(t *testing.T) { + require.True(t, zero.IsZero(), "empty scalar is not zero") + require.False(t, one.IsZero(), "1.1 should not be zero") + require.False(t, (&Scalar{1, 0}).IsZero(), "1.0 should not be zero") + require.False(t, (&Scalar{0, 1}).IsZero(), "0.1 should not be zero") + }) + + t.Run("Serialize", func(t *testing.T) { + current := &Scalar{} + for i := 0; i < 128; i++ { + data, err := current.MarshalBinary() + require.NoError(t, err, "could not marshal %s", current) + + cmpr := &Scalar{} + err = cmpr.UnmarshalBinary(data) + require.NoError(t, err, "could not unmarshal %d bytes", len(data)) + + require.Equal(t, current, cmpr, "unmarshaled scalar does not match marshaled one") + + current = randNextScalar(current) + } + }) + + t.Run("Text", func(t *testing.T) { + current := &Scalar{} + for i := 0; i < 128; i++ { + data, err := current.MarshalText() + require.NoError(t, err, "could not marshal %s", current) + + cmpr := &Scalar{} + err = cmpr.UnmarshalText(data) + require.NoError(t, err, "could not unmarshal %d bytes", len(data)) + + require.Equal(t, current, cmpr, "unmarshaled scalar does not match marshaled one") + + current = randNextScalar(current) + } + }) + + t.Run("BadText", func(t *testing.T) { + testCases := []string{ + "123", + "a.b", + "a.123", + "1.abc", + "", + "1.1.1", + "1.", + ".1", + } + + for i, tc := range testCases { + err := (&Scalar{}).UnmarshalText([]byte(tc)) + require.Error(t, err, "expected errror on test case %d", i) + } + }) + + t.Run("Binary", func(t *testing.T) { + vers := &Scalar{42, 198} + data, err := vers.MarshalBinary() + require.NoError(t, err, "could not marshal scalar as a binary value") + require.Equal(t, []byte{0x2a, 0xc6, 0x1}, data) + }) + + t.Run("JSON", func(t *testing.T) { + s := &Scalar{} + data := []byte(`"8.16"`) + + require.NoError(t, json.Unmarshal(data, s), "could not unmarshal s") + require.Equal(t, &Scalar{8, 16}, s, "incorrect unmarshal") + + cmpt, err := json.Marshal(s) + require.NoError(t, err, "could not marshal s") + require.Equal(t, data, cmpt, "unexpected marshaled data") + }) +} + +func TestCompare(t *testing.T) { + testCases := []struct { + a, b *Scalar + expected int + }{ + { + nil, nil, 0, + }, + { + &Scalar{}, &Scalar{}, 0, + }, + { + nil, &Scalar{}, 0, + }, + { + &Scalar{}, nil, 0, + }, + { + &Scalar{1, 1}, &Scalar{}, 1, + }, + { + &Scalar{}, &Scalar{1, 1}, -1, + }, + { + &Scalar{1, 2}, &Scalar{1, 1}, 1, + }, + { + &Scalar{1, 1}, &Scalar{1, 2}, -1, + }, + { + &Scalar{1, 1}, &Scalar{1, 1}, 0, + }, + { + &Scalar{2, 1}, &Scalar{1, 1}, 1, + }, + { + &Scalar{1, 1}, &Scalar{2, 1}, -1, + }, + { + &Scalar{2, 2}, &Scalar{2, 2}, 0, + }, + } + + for i, tc := range testCases { + require.Equal(t, tc.expected, Compare(tc.a, tc.b), "test case %d failed", i) + + // Handle semantic comparisons + switch tc.expected { + case 0: + require.True(t, tc.a.Equals(tc.b), "test case %d equality failed", i) + require.True(t, tc.b.Equals(tc.a), "test case %d equality failed", i) + + require.False(t, tc.a.Before(tc.b), "test case %d before failed", i) + require.False(t, tc.b.Before(tc.a), "test case %d before failed", i) + + require.False(t, tc.a.After(tc.b), "test case %d after failed", i) + require.False(t, tc.b.After(tc.a), "test case %d after failed", i) + case 1: + require.False(t, tc.a.Equals(tc.b), "test case %d equality failed", i) + require.False(t, tc.b.Equals(tc.a), "test case %d equality failed", i) + + require.False(t, tc.a.Before(tc.b), "test case %d before failed", i) + require.True(t, tc.b.Before(tc.a), "test case %d before failed", i) + + require.True(t, tc.a.After(tc.b), "test case %d after failed", i) + require.False(t, tc.b.After(tc.a), "test case %d after failed", i) + case -1: + require.False(t, tc.a.Equals(tc.b), "test case %d equality failed", i) + require.False(t, tc.b.Equals(tc.a), "test case %d equality failed", i) + + require.True(t, tc.a.Before(tc.b), "test case %d before failed", i) + require.False(t, tc.b.Before(tc.a), "test case %d before failed", i) + + require.False(t, tc.a.After(tc.b), "test case %d after failed", i) + require.True(t, tc.b.After(tc.a), "test case %d after failed", i) + } + + } +} + +func randNextScalar(prev *Scalar) *Scalar { + s := &Scalar{} + s.PID = uint32(rand.Int32N(24)) + s.VID = uint64(rand.Int64N(32)) + prev.VID + return s +} diff --git a/pkg/store/metadata/testdata/metadata.json b/pkg/store/metadata/testdata/metadata.json index 8aea3fd..9efc2d0 100644 --- a/pkg/store/metadata/testdata/metadata.json +++ b/pkg/store/metadata/testdata/metadata.json @@ -1,14 +1,8 @@ { "version": { - "pid": 8, - "version": 12, + "scalar": "8.12", "region": "us-central1", - "parent": { - "pid": 3, - "version": 11, - "region": "us-central1", - "created": "2024-11-30T10:29:59Z" - }, + "parent": "3.11", "created": "2024-11-30T10:29:59Z" }, "schema": { diff --git a/pkg/store/metadata/version.go b/pkg/store/metadata/version.go index e9b7245..f23be3a 100644 --- a/pkg/store/metadata/version.go +++ b/pkg/store/metadata/version.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "time" + "github.com/rotationalio/honu/pkg/store/lamport" "github.com/rotationalio/honu/pkg/store/lani" ) @@ -12,25 +13,22 @@ import ( //=========================================================================== type Version struct { - PID uint64 `json:"pid" msg:"pid"` - Version uint64 `json:"version" msg:"version"` - Region string `json:"region" msg:"region"` - Parent *Version `json:"parent,omitempty" msg:"parent,omitempty"` - Tombstone bool `json:"tombstone,omitempty" msg:"tombstone,omitempty"` - Created time.Time `json:"created" msg:"created"` + Scalar lamport.Scalar `json:"scalar" msg:"scalar"` + Region string `json:"region" msg:"region"` + Parent *lamport.Scalar `json:"parent,omitempty" msg:"parent,omitempty"` + Tombstone bool `json:"tombstone,omitempty" msg:"tombstone,omitempty"` + Created time.Time `json:"created" msg:"created"` } var _ lani.Encodable = &Version{} var _ lani.Decodable = &Version{} func (o *Version) Size() (s int) { - s += 2 * binary.MaxVarintLen64 - s += len([]byte(o.Region)) + binary.MaxVarintLen64 + s += o.Scalar.Size() // Scalar uint32 + uint64 + s += 1 // Add 1 for the parent nil bool if o.Parent != nil { - s += o.Parent.Size() + 1 // Add 1 for the not nil bool - } else { - s += 1 // Add 1 for the nil bool + s += o.Parent.Size() } s += 1 // Tombstone bool @@ -41,12 +39,7 @@ func (o *Version) Size() (s int) { func (o *Version) Encode(e *lani.Encoder) (n int, err error) { var m int - if m, err = e.EncodeUint64(o.PID); err != nil { - return n + m, err - } - n += m - - if m, err = e.EncodeUint64(o.Version); err != nil { + if m, err = o.Scalar.Encode(e); err != nil { return n + m, err } n += m @@ -75,11 +68,7 @@ func (o *Version) Encode(e *lani.Encoder) (n int, err error) { } func (o *Version) Decode(d *lani.Decoder) (err error) { - if o.PID, err = d.DecodeUint64(); err != nil { - return err - } - - if o.Version, err = d.DecodeUint64(); err != nil { + if err = o.Scalar.Decode(d); err != nil { return err } @@ -88,7 +77,7 @@ func (o *Version) Decode(d *lani.Decoder) (err error) { } var isNil bool - o.Parent = &Version{} + o.Parent = &lamport.Scalar{} if isNil, err = d.DecodeStruct(o.Parent); err != nil { return err diff --git a/pkg/store/object/object_test.go b/pkg/store/object/object_test.go index a619c35..1d7ab82 100644 --- a/pkg/store/object/object_test.go +++ b/pkg/store/object/object_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/oklog/ulid/v2" + "github.com/rotationalio/honu/pkg/store/lamport" "github.com/rotationalio/honu/pkg/store/metadata" "github.com/rotationalio/honu/pkg/store/object" "github.com/stretchr/testify/require" @@ -22,7 +23,7 @@ func TestObject(t *testing.T) { obj, err := object.Marshal(meta, data) require.NoError(t, err, "could not marshal object") - require.Len(t, obj, 1271, "unexpected length of encoded object") + require.Len(t, obj, 1248, "unexpected length of encoded object") require.Equal(t, object.StorageVersion, obj.StorageVersion()) ometa, err := obj.Metadata() @@ -141,7 +142,7 @@ func BenchmarkSerialization(b *testing.B) { func generateRandomObject(size Size) (*metadata.Metadata, []byte) { obj := &metadata.Metadata{ - Version: randVersion(false), + Version: randVersion(), Schema: randSchema(), MIME: "application/random", Owner: ulid.MustNew(ulid.Now(), rand.Reader), @@ -165,22 +166,17 @@ func generateRandomObject(size Size) (*metadata.Metadata, []byte) { return obj, data } -func randVersion(isParent bool) *metadata.Version { - // 10% chance of nil - if mrand.Float32() < 0.1 { - return nil - } - +func randVersion() *metadata.Version { vers := &metadata.Version{ - PID: mrand.Uint64(), - Version: mrand.Uint64(), + Scalar: lamport.Scalar{PID: mrand.Uint32(), VID: mrand.Uint64()}, Region: randRegion(), Tombstone: mrand.Float32() < 0.25, Created: randTime(), } - if !isParent { - vers.Parent = randVersion(true) + // 10% chance of nil parent + if mrand.Float32() < 0.9 { + vers.Parent = &lamport.Scalar{PID: mrand.Uint32(), VID: mrand.Uint64()} } return vers diff --git a/pkg/store/object/testdata/metadata.json b/pkg/store/object/testdata/metadata.json index 31dbf33..35c750a 100644 --- a/pkg/store/object/testdata/metadata.json +++ b/pkg/store/object/testdata/metadata.json @@ -1,14 +1,8 @@ { "version": { - "pid": 8, - "version": 12, + "scalar": "8.12", "region": "us-central1", - "parent": { - "pid": 3, - "version": 11, - "region": "us-central1", - "created": "2024-11-30T10:29:59Z" - }, + "parent": "3.11", "created": "2024-11-30T10:29:59Z" }, "schema": {