From 325b71da0c6d5593b8b0b18ccbfe6f22ffc229b0 Mon Sep 17 00:00:00 2001
From: Ashton Kinslow
Date: Thu, 9 Nov 2023 01:20:48 -0500
Subject: [PATCH] initial lockfree attempt

---
 autopaho/queue/benchmarks/bench_test.go | 113 ++++++++++++
 autopaho/queue/lockfree/queue.go        | 135 ++++++++++++++
 autopaho/queue/lockfree/queue_test.go   | 226 ++++++++++++++++++++++++
 autopaho/queue/queue.go                 |   1 +
 4 files changed, 475 insertions(+)
 create mode 100644 autopaho/queue/benchmarks/bench_test.go
 create mode 100644 autopaho/queue/lockfree/queue.go
 create mode 100644 autopaho/queue/lockfree/queue_test.go

diff --git a/autopaho/queue/benchmarks/bench_test.go b/autopaho/queue/benchmarks/bench_test.go
new file mode 100644
index 0000000..0493f70
--- /dev/null
+++ b/autopaho/queue/benchmarks/bench_test.go
@@ -0,0 +1,113 @@
+package benchmarks
+
+import (
+	"runtime"
+	"strings"
+	"testing"
+
+	"github.com/eclipse/paho.golang/autopaho/queue"
+	"github.com/eclipse/paho.golang/autopaho/queue/lockfree"
+	"github.com/eclipse/paho.golang/autopaho/queue/memory"
+)
+
+func benchmarkConcurrentEnqueueDequeue(b *testing.B, q queue.Queue) {
+	workers := runtime.GOMAXPROCS(0) // Use as many goroutines as there are available CPUs
+
+	b.ResetTimer()
+	b.ReportAllocs()
+	b.SetParallelism(workers) // RunParallel will use workers*GOMAXPROCS goroutines
+
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			// A fresh reader per iteration: Enqueue drains it, and sharing one across goroutines would race
+			if err := q.Enqueue(strings.NewReader("test data")); err != nil {
+				b.Fatal(err)
+			}
+			if err := q.Dequeue(); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+}
+
+func seedQueue(q queue.Queue, count int) queue.Queue {
+	for i := 0; i < count; i++ {
+		data := strings.NewReader("test data")
+		q.Enqueue(data)
+	}
+	return q
+}
+
+func benchmarkEnqueue(b *testing.B, q queue.Queue) {
+	b.ResetTimer()
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		// A fresh reader is needed every iteration; io.ReadAll drains it on the first Enqueue
+		data := strings.NewReader("test data")
+		if err := q.Enqueue(data); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+func benchmarkPeek(b *testing.B, q queue.Queue) {
+	b.ResetTimer()
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		if _, err := q.Peek(); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+func benchmarkDequeue(b *testing.B, q queue.Queue) {
+	b.ResetTimer()
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		q.Dequeue()
+	}
+}
+
+func BenchmarkEnqueueMemory(b *testing.B) {
+	q := memory.New()
+	benchmarkEnqueue(b, q)
+}
+
+func BenchmarkEnqueueLockFree(b *testing.B) {
+	q := lockfree.New()
+	benchmarkEnqueue(b, q)
+}
+
+func BenchmarkDequeueMemory(b *testing.B) {
+	q := memory.New()
+	seedQueue(q, 1000)
+	benchmarkDequeue(b, q)
+}
+
+func BenchmarkDequeueLockFree(b *testing.B) {
+	q := lockfree.New()
+	seedQueue(q, 1000)
+	benchmarkDequeue(b, q)
+}
+
+func BenchmarkPeekMemory(b *testing.B) {
+	q := memory.New()
+	seedQueue(q, 1000)
+	benchmarkPeek(b, q)
+}
+
+func BenchmarkPeekLockFree(b *testing.B) {
+	q := lockfree.New()
+	seedQueue(q, 1000)
+	benchmarkPeek(b, q)
+}
+
+func BenchmarkConcurrentMemory(b *testing.B) {
+	q := memory.New()
+	benchmarkConcurrentEnqueueDequeue(b, q)
+}
+
+func BenchmarkConcurrentLockFree(b *testing.B) {
+	q := lockfree.New()
+	benchmarkConcurrentEnqueueDequeue(b, q)
+}
diff --git a/autopaho/queue/lockfree/queue.go b/autopaho/queue/lockfree/queue.go
new file mode 100644
index 0000000..5eb6dfd
--- /dev/null
+++ b/autopaho/queue/lockfree/queue.go
@@ -0,0 +1,135 @@
+package lockfree
+
+import (
+	"bytes"
+	"io"
+	"sync/atomic"
+	"unsafe"
+
+	"github.com/eclipse/paho.golang/autopaho/queue"
+)
+
+type Queue struct {
+	head     unsafe.Pointer // *node
+	tail     unsafe.Pointer // *node
+	waitChan unsafe.Pointer // *chan struct{}
+}
+
+type node struct {
+	value []byte
+	next  unsafe.Pointer // *node
+}
+
+// New creates an empty lock-free queue, initialised with a dummy node.
+func New() *Queue {
+	dummy := &node{}
+	return &Queue{
+		head: unsafe.Pointer(dummy),
+		tail: unsafe.Pointer(dummy),
+	}
+}
+
+// Enqueue adds an item to the queue.
+func (q *Queue) Enqueue(p io.Reader) error {
+	data, err := io.ReadAll(p)
+	if err != nil {
+		return err
+	}
+
+	n := &node{value: data}
+	for {
+		tail := (*node)(atomic.LoadPointer(&q.tail))
+		next := (*node)(atomic.LoadPointer(&tail.next))
+		if tail == (*node)(atomic.LoadPointer(&q.tail)) { // Still the tail?
+			if next == nil {
+				if atomic.CompareAndSwapPointer(&tail.next, nil, unsafe.Pointer(n)) {
+					atomic.CompareAndSwapPointer(&q.tail, unsafe.Pointer(tail), unsafe.Pointer(n))
+					// Signal that the queue is not empty if needed
+					q.signalNotEmpty()
+					return nil
+				}
+			} else {
+				// Tail is lagging; help advance it before retrying
+				atomic.CompareAndSwapPointer(&q.tail, unsafe.Pointer(tail), unsafe.Pointer(next))
+			}
+		}
+	}
+}
+
+// Dequeue removes the oldest item from the queue.
+func (q *Queue) Dequeue() error {
+	for {
+		head := (*node)(atomic.LoadPointer(&q.head))
+		tail := (*node)(atomic.LoadPointer(&q.tail))
+		next := (*node)(atomic.LoadPointer(&head.next))
+		if head == (*node)(atomic.LoadPointer(&q.head)) { // Still the head?
+			if head == tail {
+				if next == nil {
+					return queue.ErrEmpty // Queue is empty
+				}
+				// Tail falling behind, advance it
+				atomic.CompareAndSwapPointer(&q.tail, unsafe.Pointer(tail), unsafe.Pointer(next))
+			} else {
+				// Swing head past the current dummy node; next becomes the new dummy
+				if atomic.CompareAndSwapPointer(&q.head, unsafe.Pointer(head), unsafe.Pointer(next)) {
+					return nil
+				}
+			}
+		}
+	}
+}
+
+// Peek retrieves the oldest item from the queue without removing it.
+func (q *Queue) Peek() (io.ReadCloser, error) {
+	head := (*node)(atomic.LoadPointer(&q.head))
+	next := (*node)(atomic.LoadPointer(&head.next))
+	if next != nil { // There is an item in the queue
+		return io.NopCloser(bytes.NewReader(next.value)), nil
+	}
+	// The queue is empty
+	return nil, queue.ErrEmpty
+}
+
+// Wait returns a channel that is closed when there is something in the queue.
+func (q *Queue) Wait() chan struct{} {
+	for {
+		if !q.isEmpty() {
+			// If the queue is not empty, return a closed channel
+			c := make(chan struct{})
+			close(c)
+			return c
+		}
+
+		// Attempt to create a wait channel if it doesn't exist
+		if atomic.LoadPointer(&q.waitChan) == nil {
+			newCh := make(chan struct{})
+			if atomic.CompareAndSwapPointer(&q.waitChan, nil, unsafe.Pointer(&newCh)) {
+				// Re-check: an Enqueue may have completed before the channel was registered,
+				// in which case it could not have seen the channel to close it.
+				if !q.isEmpty() {
+					q.signalNotEmpty()
+				}
+				return newCh
+			}
+		}
+	}
+}
+
+// isEmpty checks if the queue is empty.
+func (q *Queue) isEmpty() bool {
+	head := (*node)(atomic.LoadPointer(&q.head))
+	tail := (*node)(atomic.LoadPointer(&q.tail))
+	next := (*node)(atomic.LoadPointer(&head.next))
+	return head == tail && next == nil
+}
+
+// signalNotEmpty signals that the queue is not empty.
+func (q *Queue) signalNotEmpty() {
+	chPtr := atomic.LoadPointer(&q.waitChan)
+	// Only the goroutine that wins the CAS closes the channel; this avoids a
+	// double close when concurrent Enqueue calls race to signal.
+	if chPtr != nil && atomic.CompareAndSwapPointer(&q.waitChan, chPtr, nil) {
+		ch := *(*chan struct{})(chPtr)
+		close(ch) // Close the channel to signal that the queue is not empty
+	}
+}
diff --git a/autopaho/queue/lockfree/queue_test.go b/autopaho/queue/lockfree/queue_test.go
new file mode 100644
index 0000000..753c55a
--- /dev/null
+++ b/autopaho/queue/lockfree/queue_test.go
@@ -0,0 +1,226 @@
+package lockfree
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"io"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/eclipse/paho.golang/autopaho/queue"
+)
+
+// TestLockFree some basic tests of the queue
+func TestLockFree(t *testing.T) {
+	q := New()
+
+	if _, err := q.Peek(); !errors.Is(err, queue.ErrEmpty) {
+		t.Errorf("expected ErrEmpty, got %s", err)
+	}
+
+	if err := q.Dequeue(); !errors.Is(err, queue.ErrEmpty) {
+		t.Errorf("expected ErrEmpty, got %s", err)
+	}
+
+	queueNotEmpty := make(chan struct{})
+	go func() {
+		<-q.Wait()
+		close(queueNotEmpty)
+	}()
+	time.Sleep(time.Nanosecond) // let goroutine run
+	select {
+	case <-queueNotEmpty:
+		t.Fatalf("Wait should not return until something is in queue")
+	default:
+	}
+	testEntry := []byte("This is a test")
+	if err := q.Enqueue(bytes.NewReader(testEntry)); err != nil {
+		t.Fatalf("error adding to queue: %s", err)
+	}
+	select {
+	case <-queueNotEmpty:
+	case <-time.After(time.Second):
+		t.Fatalf("Wait should return when something is in queue")
+	}
+
+	const entryFormat = "Queue entry %d for testing"
+	for i := 0; i < 10; i++ {
+		if err := q.Enqueue(bytes.NewReader([]byte(fmt.Sprintf(entryFormat, i)))); err != nil {
+			t.Fatalf("error adding entry %d: %s", i, err)
+		}
+	}
+	if err := q.Dequeue(); err != nil {
+		t.Fatalf("error dequeue entry: %s", err)
+	}
+
+	for i := 0; i < 10; i++ {
+		r, err := q.Peek()
+		if err != nil {
+			t.Fatalf("error peeking entry %d: %s", i, err)
+		}
+		buf := &bytes.Buffer{}
+		if _, err = buf.ReadFrom(r); err != nil {
+			t.Fatalf("error reading entry %d: %s", i, err)
+		}
+		if err = r.Close(); err != nil {
+			t.Fatalf("error closing queue entry %d: %s", i, err)
+		}
+
+		expected := []byte(fmt.Sprintf(entryFormat, i))
+		if !bytes.Equal(expected, buf.Bytes()) {
+			t.Fatalf("expected \"%s\", got \"%s\"", expected, buf.Bytes())
+		}
+		if err = q.Dequeue(); err != nil {
+			t.Fatalf("error dequeue entry %d: %s", i, err)
+		}
+	}
+
+	if _, err := q.Peek(); !errors.Is(err, queue.ErrEmpty) {
+		t.Errorf("expected ErrEmpty, got %s", err)
+	}
+
+	if err := q.Dequeue(); !errors.Is(err, queue.ErrEmpty) {
+		t.Errorf("expected ErrEmpty, got %s", err)
+	}
+}
+
+// TestMultipleWait ensures that multiple goroutines waiting on the queue
+// all receive the signal when a new item is enqueued.
+func TestMultipleWait(t *testing.T) {
+	q := New()
+	var wg sync.WaitGroup
+	waiters := 5
+
+	// Start multiple goroutines that will wait for the signal
+	for i := 0; i < waiters; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			waitCh := q.Wait()
+			<-waitCh // Wait for the signal
+		}()
+	}
+
+	// Give the goroutines time to start and call Wait()
+	time.Sleep(100 * time.Millisecond)
+
+	// Enqueue an item, which should close the wait channel and signal all waiting goroutines
+	err := q.Enqueue(strings.NewReader("data"))
+	if err != nil {
+		t.Fatalf("Enqueue failed: %v", err)
+	}
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		wg.Wait() // Ensure all waiting goroutines have finished
+	}()
+
+	select {
+	case <-done:
+		// Test passed, all goroutines received the signal
+	case <-time.After(30 * time.Second):
+		t.Fatal("Test timed out: not all goroutines received the signal")
+	}
+}
+
+// TestMultiplePeek ensures that multiple Peek calls return the correct value
+// and do not remove the item from the queue.
+func TestMultiplePeek(t *testing.T) {
+	q := New()
+	input := "data"
+
+	// Enqueue an item
+	err := q.Enqueue(strings.NewReader(input))
+	if err != nil {
+		t.Fatalf("Enqueue failed: %v", err)
+	}
+
+	// Start multiple Peek calls
+	var wg sync.WaitGroup
+	peekers := 5
+	for i := 0; i < peekers; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			reader, err := q.Peek()
+			if err != nil {
+				t.Errorf("Peek failed: %v", err)
+				return
+			}
+			data, err := io.ReadAll(reader)
+			if err != nil {
+				t.Errorf("ReadAll failed: %v", err)
+				return
+			}
+			if string(data) != input {
+				t.Errorf("Peek returned incorrect data: got %v, want %v", string(data), input)
+			}
+		}()
+	}
+
+	wg.Wait() // Ensure all Peek calls have finished
+
+	// The item should still be in the queue after multiple Peek calls
+	_, err = q.Peek()
+	if err != nil {
+		t.Errorf("Item was removed from the queue after Peek: %v", err)
+	}
+}
+
+// TestHighConcurrency tests the queue with a high number of concurrent Enqueue and Dequeue operations.
+func TestHighConcurrency(t *testing.T) {
+	q := New()
+	var wg sync.WaitGroup
+	workers := 100
+	itemsPerWorker := 1000
+
+	// Enqueue items
+	for i := 0; i < workers; i++ {
+		wg.Add(1)
+		go func(workerID int) {
+			defer wg.Done()
+			for j := 0; j < itemsPerWorker; j++ {
+				data := strings.NewReader(fmt.Sprintf("%d-%d", workerID, j))
+				err := q.Enqueue(data)
+				if err != nil {
+					t.Errorf("Enqueue failed: %v", err)
+				}
+			}
+		}(i)
+	}
+
+	// Dequeue items
+	var dequeueCount int32
+	for i := 0; i < workers; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for {
+				err := q.Dequeue()
+				if errors.Is(err, queue.ErrEmpty) {
+					if atomic.LoadInt32(&dequeueCount) == int32(workers*itemsPerWorker) {
+						return
+					}
+					continue
+				} else if err != nil {
+					t.Errorf("Dequeue failed: %v", err)
+					return
+				}
+				atomic.AddInt32(&dequeueCount, 1)
+			}
+		}()
+	}
+
+	wg.Wait() // Wait for all operations to complete
+
+	// Use atomic read to get the final value of dequeueCount
+	finalDequeueCount := atomic.LoadInt32(&dequeueCount)
+	if finalDequeueCount != int32(workers*itemsPerWorker) {
+		t.Errorf("Dequeue count mismatch: got %v, want %v", finalDequeueCount, workers*itemsPerWorker)
+	}
+}
diff --git a/autopaho/queue/queue.go b/autopaho/queue/queue.go
index 342bf75..e6281ba 100644
--- a/autopaho/queue/queue.go
+++ b/autopaho/queue/queue.go
@@ -13,6 +13,7 @@ var (
 type Queue interface {
 	// Wait returns a channel that is closed when there is something in the queue (will return a closed channel if the
 	// queue is empty at the time of the call)
+	// Wait may be called multiple times, including concurrently.
 	Wait() chan struct{}
 
 	// Enqueue add item to the queue.
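
Reviewer note: below is a minimal usage sketch (not part of the patch) showing how a consumer might drive the new queue through the Wait/Peek/Dequeue API added above. The drain helper and main scaffolding are illustrative assumptions only; lockfree.New, queue.ErrEmpty and the method set all come from the files in this patch.

package main

import (
	"errors"
	"fmt"
	"io"
	"strings"

	"github.com/eclipse/paho.golang/autopaho/queue"
	"github.com/eclipse/paho.golang/autopaho/queue/lockfree"
)

// drain is a hypothetical helper: it reads and removes every item currently
// in the queue using the Peek/Dequeue pair, stopping when ErrEmpty is seen.
func drain(q *lockfree.Queue) error {
	for {
		r, err := q.Peek()
		if errors.Is(err, queue.ErrEmpty) {
			return nil // nothing left to consume
		}
		if err != nil {
			return err
		}
		data, err := io.ReadAll(r)
		if err != nil {
			return err
		}
		_ = r.Close()
		fmt.Printf("got: %s\n", data)
		if err := q.Dequeue(); err != nil {
			return err
		}
	}
}

func main() {
	q := lockfree.New()
	_ = q.Enqueue(strings.NewReader("hello"))
	<-q.Wait() // queue is non-empty, so Wait returns an already-closed channel
	if err := drain(q); err != nil {
		panic(err)
	}
}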