Skip to content

Commit

Permalink
Conflict Free Distributed Version Numbers (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
bbengfort authored Dec 12, 2024
1 parent 1e07f9c commit 7e3f9ee
Show file tree
Hide file tree
Showing 10 changed files with 545 additions and 51 deletions.
50 changes: 50 additions & 0 deletions pkg/store/lamport/lamport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package lamport

import "sync"

// New returns a new clock with the specified PID. The returned clock is thread-safe
// and uses a mutex to guard against updates from multiple threads.
func New(pid uint32) Clock {
return &clock{pid: pid}
}

// A Lamport Clock keeps track of the current process instance and monotonically
// increasing sequence timestamp. It is used to track the conflict-free distributed
// version scalar inside of a replicated system.
type Clock interface {
// Return the next timestamp using the internal process ID of the clock.
Next() Scalar

// Update the clock with a timestamp scalar; if the scalar happens before the
// current timestamp in the clock it is ignored.
Update(Scalar)
}

var _ Clock = &clock{}

type clock struct {
sync.Mutex
pid uint32
current Scalar
}

func (c *clock) Next() Scalar {
c.Lock()
defer c.Unlock()

c.current = Scalar{
PID: c.pid,
VID: c.current.VID + 1,
}

// Ensure a copy is returned so that the clock cannot be modified externally
return c.current
}

func (c *clock) Update(now Scalar) {
c.Lock()
defer c.Unlock()
if now.After(&c.current) {
c.current = now
}
}
95 changes: 95 additions & 0 deletions pkg/store/lamport/lamport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package lamport_test

import (
"math/rand"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/rotationalio/honu/pkg/store/lamport"
"github.com/stretchr/testify/require"
)

func TestClock(t *testing.T) {
clock := lamport.New(1)
require.Equal(t, lamport.Scalar{1, 1}, clock.Next(), "expected next timestamp to be 1")

clock.Update(lamport.Scalar{})
require.Equal(t, lamport.Scalar{1, 2}, clock.Next(), "expected next timestamp to be 2")

clock.Update(lamport.Scalar{2, 2})
require.Equal(t, lamport.Scalar{1, 3}, clock.Next(), "expected next timestamp to be 3")

clock.Update(lamport.Scalar{2, 6})
require.Equal(t, lamport.Scalar{1, 7}, clock.Next(), "expected next timestamp to be 7")
}

func TestClockConcurrency(t *testing.T) {
// Test concurrent clock operations by running a large number of threads with
// indpendent read and write clocks and ensure that the next version number is
// always after the previous version number when generated.
var (
wg, rg sync.WaitGroup
failures int64
)

// Create broadcast channel to simulate synchronization
C := make([]chan *lamport.Scalar, 0, 3)

for i := 1; i < 4; i++ {
bc := make(chan *lamport.Scalar, 1)

wg.Add(1)
go func(wg *sync.WaitGroup, c chan *lamport.Scalar) {
defer wg.Done()
defer close(c)

var sg sync.WaitGroup
clock := lamport.New(uint32(i))

// Kick off the receiver routine
rg.Add(1)
go func(rg *sync.WaitGroup, c <-chan *lamport.Scalar) {
defer rg.Done()
for v := range c {
clock.Update(*v)
}
}(&rg, c)

// Kick off several updater routines
for j := 0; j < 3; j++ {
sg.Add(1)
go func(sg *sync.WaitGroup) {
defer sg.Done()
prev := &lamport.Scalar{}
for k := 0; k < 16; k++ {
time.Sleep(time.Millisecond * time.Duration(rand.Int63n(32)+2))

now := clock.Next()
if !now.After(prev) {
atomic.AddInt64(&failures, 1)
return
}

// Broadcast the update
prev = &now
for _, c := range C {
c <- prev
}
}
}(&sg)
}

sg.Wait()
}(&wg, bc)
}

wg.Wait()
for _, c := range C {
close(c)
}
rg.Wait()

require.Zero(t, failures, "expected zero failures to have occurred")
}
11 changes: 11 additions & 0 deletions pkg/store/lamport/pid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package lamport

// The PID type makes it easy to create scalar versions from a previous version.
type PID uint32

func (p PID) Next(v *Scalar) *Scalar {
if v == nil {
return &Scalar{PID: uint32(p), VID: 1}
}
return &Scalar{PID: uint32(p), VID: v.VID + 1}
}
15 changes: 15 additions & 0 deletions pkg/store/lamport/pid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package lamport_test

import (
"testing"

"github.com/rotationalio/honu/pkg/store/lamport"
"github.com/stretchr/testify/require"
)

func TestPID(t *testing.T) {
pid := lamport.PID(42)
require.Equal(t, &lamport.Scalar{42, 1}, pid.Next(nil), "nil pid didn't return expected next")
require.Equal(t, &lamport.Scalar{42, 1}, pid.Next(&lamport.Scalar{}), "zero pid didn't return expected next")
require.Equal(t, &lamport.Scalar{42, 19}, pid.Next(&lamport.Scalar{42, 18}), "same pid didn't return expected next")
}
172 changes: 172 additions & 0 deletions pkg/store/lamport/scalar.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package lamport

import (
"bytes"
"encoding"
"encoding/binary"
"errors"
"fmt"
"regexp"
"strconv"

"github.com/rotationalio/honu/pkg/store/lani"
)

// A lamport scalar indicates a timestamp or version in a vector clock that provides a
// "happens before" relationship between two scalars and is used to create distributed
// version numbers for eventually consistent systems and uses a "latest writer wins"
// policy to determine how to solve conflicts.
//
// Each scalar is 12 bytes and is composed of an 4 byte PID which should be unique to
// every process in the system, and an 8 byte monotonically increasing VID that
// represents the next latest version. In the case of a scalar with the same VID, the
// scalar with the larger PID happens before the scalar with the smaller PID (e.g.
// older processes, processes with smaller PIDs, win ties).
type Scalar struct {
PID uint32
VID uint64
}

var (
zero = &Scalar{0, 0}
scre = regexp.MustCompile(`^(\d+)\.(\d+)$`)
_ lani.Encodable = &Scalar{}
_ lani.Decodable = &Scalar{}
_ encoding.BinaryMarshaler = &Scalar{}
_ encoding.BinaryUnmarshaler = &Scalar{}
_ encoding.TextMarshaler = &Scalar{}
_ encoding.TextUnmarshaler = &Scalar{}
)

const scalarSize = binary.MaxVarintLen32 + binary.MaxVarintLen64

// Compare returns an integer comparing two scalars using a happens before relationship.
// The result will be 0 if a == b, -1 if a < b (e.g. a happens before b), and
// +1 if a > b (e.g. b happens before a). A nil argument is equivalent to a zero scalar.
func Compare(a, b *Scalar) int {
if a == nil && b == nil {
return 0
}

if a == nil {
a = zero
}

if b == nil {
b = zero
}

if a.VID == b.VID {
switch {
case a.PID < b.PID:
return -1
case a.PID > b.PID:
return 1
default:
return 0
}
}

if a.VID > b.VID {
return 1
}
return -1
}

// Returns true if the scalar is the zero-valued scalar (0.0)
func (s *Scalar) IsZero() bool {
return s.PID == 0 && s.VID == 0
}

// Returns true if the scalar is equal to the input scalar.
func (s *Scalar) Equals(o *Scalar) bool {
return Compare(s, o) == 0
}

// Returns true if the scalar is less than the input scalar (e.g. this scalar happens
// before the input scalar).
func (s *Scalar) Before(o *Scalar) bool {
return Compare(s, o) < 0
}

// Returns true if the scalar is grater than the input scalar (e.g. the input scalar
// hapens before this scalar).
func (s *Scalar) After(o *Scalar) bool {
return Compare(s, o) > 0
}

func (s *Scalar) Size() int {
return scalarSize
}

func (s *Scalar) Encode(e *lani.Encoder) (n int, err error) {
var m int
if m, err = e.EncodeUint32(s.PID); err != nil {
return n, err
}
n += m

if m, err = e.EncodeUint64(s.VID); err != nil {
return n, err
}
n += m

return n, nil
}

func (s *Scalar) Decode(d *lani.Decoder) (err error) {
if s.PID, err = d.DecodeUint32(); err != nil {
return err
}

if s.VID, err = d.DecodeUint64(); err != nil {
return err
}

return nil
}

func (s *Scalar) MarshalBinary() ([]byte, error) {
e := &lani.Encoder{}
e.Grow(s.Size())

if _, err := s.Encode(e); err != nil {
return nil, err
}

return e.Bytes(), nil
}

func (s *Scalar) UnmarshalBinary(data []byte) (err error) {
d := lani.NewDecoder(data)
return s.Decode(d)
}

func (s *Scalar) MarshalText() (_ []byte, err error) {
return []byte(s.String()), nil
}

func (s *Scalar) UnmarshalText(text []byte) (err error) {
if !scre.Match(text) {
return errors.New("could not parse text representation of a scalar")
}

parts := bytes.Split(text, []byte{'.'})

var pid uint64
if pid, err = strconv.ParseUint(string(parts[0]), 10, 32); err != nil {
panic("pid is not parseable even though regular expression matched.")
}
s.PID = uint32(pid)

if s.VID, err = strconv.ParseUint(string(parts[1]), 10, 32); err != nil {
panic("pid is not parseable even though regular expression matched.")
}

return nil
}

// Returns a scalar version representation in the form PID.VID using decimal notation.
func (s *Scalar) String() string {
return fmt.Sprintf("%d.%d", s.PID, s.VID)
}
Loading

0 comments on commit 7e3f9ee

Please sign in to comment.