From 698d6d8ce9eb46f9570896331482810144c22657 Mon Sep 17 00:00:00 2001 From: John Date: Tue, 22 Oct 2024 15:59:32 +0800 Subject: [PATCH] Fix queue potential bugs --- queue.go | 26 +++++++------------------- queue_test.go | 47 ++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 22 deletions(-) diff --git a/queue.go b/queue.go index df9b3ee..60896c3 100644 --- a/queue.go +++ b/queue.go @@ -547,11 +547,11 @@ func (q *BufferedChannelQueue[T]) freeNodePool() { break } + q.lock.Lock() if q.pool.nodeCount > q.nodeHookPoolSize { - q.lock.Lock() q.pool.KeepNodePoolCount(q.nodeHookPoolSize) - q.lock.Unlock() } + q.lock.Unlock() } } @@ -589,14 +589,8 @@ func (q *BufferedChannelQueue[T]) loadFromPool() { } func (q *BufferedChannelQueue[T]) notifyWorkers() { - q.lock.RLock() - if q.pool.Count() > 0 { - q.loadWorkerCh.Offer(1) - } - if q.pool.nodeCount > q.nodeHookPoolSize { - q.freeNodeWorkerCh.Offer(1) - } - q.lock.RUnlock() + q.loadWorkerCh.Offer(1) + q.freeNodeWorkerCh.Offer(1) } // SetBufferSizeMaximum Set MaximumBufferSize(maximum number of buffered items outside the ChannelQueue) @@ -656,6 +650,9 @@ func (q *BufferedChannelQueue[T]) Count() int { return 0 } + q.lock.RLock() + defer q.lock.RUnlock() + return len(q.blockingQueue) + q.pool.Count() } @@ -711,9 +708,6 @@ func (q *BufferedChannelQueue[T]) Put(val T) error { // Take Take the T val(blocking) func (q *BufferedChannelQueue[T]) Take() (T, error) { - // q.lock.RLock() - // defer q.lock.RUnlock() - if q.isClosed.Get() { return *new(T), ErrQueueIsClosed } @@ -725,9 +719,6 @@ func (q *BufferedChannelQueue[T]) Take() (T, error) { // TakeWithTimeout Take the T val(blocking), with timeout func (q *BufferedChannelQueue[T]) TakeWithTimeout(timeout time.Duration) (T, error) { - // q.lock.RLock() - // defer q.lock.RUnlock() - if q.isClosed.Get() { return *new(T), ErrQueueIsClosed } @@ -775,9 +766,6 @@ func (q *BufferedChannelQueue[T]) Offer(val T) error { // Poll Poll the T val(non-blocking) func (q *BufferedChannelQueue[T]) Poll() (T, error) { - // q.lock.RLock() - // defer q.lock.RUnlock() - if q.isClosed.Get() { return *new(T), ErrQueueIsClosed } diff --git a/queue_test.go b/queue_test.go index 7ee7a5e..b379af8 100644 --- a/queue_test.go +++ b/queue_test.go @@ -301,10 +301,51 @@ func TestNewBufferedChannelQueue(t *testing.T) { assert.Equal(t, nil, err) // Async + asyncTaskDone := make(chan bool) + + bufferedChannelQueue.SetBufferSizeMaximum(6) + timeout = 2 * time.Millisecond + go func() { + time.Sleep(timeout) + result, err = bufferedChannelQueue.TakeWithTimeout(timeout) + assert.Equal(t, nil, err) + assert.Equal(t, 1, result) + result, err = bufferedChannelQueue.TakeWithTimeout(timeout) + assert.Equal(t, nil, err) + assert.Equal(t, 2, result) + result, err = bufferedChannelQueue.TakeWithTimeout(timeout) + assert.Equal(t, nil, err) + assert.Equal(t, 3, result) + result, err = bufferedChannelQueue.TakeWithTimeout(timeout) + assert.Equal(t, nil, err) + assert.Equal(t, 4, result) + result, err = bufferedChannelQueue.TakeWithTimeout(timeout) + assert.Equal(t, nil, err) + assert.Equal(t, 5, result) + result, err = bufferedChannelQueue.TakeWithTimeout(timeout) + assert.Equal(t, nil, err) + assert.Equal(t, 6, result) + asyncTaskDone <- true + }() + go func() { + err = bufferedChannelQueue.Put(1) + assert.Equal(t, nil, err) + err = bufferedChannelQueue.Put(2) + assert.Equal(t, nil, err) + err = bufferedChannelQueue.Put(3) + assert.Equal(t, nil, err) + err = bufferedChannelQueue.Put(4) + assert.Equal(t, nil, err) + err = bufferedChannelQueue.Put(5) + assert.Equal(t, nil, err) + err = bufferedChannelQueue.Put(6) + assert.Equal(t, nil, err) + }() + + <-asyncTaskDone bufferedChannelQueue.SetBufferSizeMaximum(10000) - timeout = 1 * time.Millisecond - asyncTaskDone := make(chan bool) + timeout = 10 * time.Millisecond go func() { for i := 1; i <= 10000; i++ { result, err := bufferedChannelQueue.TakeWithTimeout(timeout) @@ -331,6 +372,6 @@ func TestNewBufferedChannelQueue(t *testing.T) { time.Sleep(1 * timeout) - assert.GreaterOrEqual(t, 100, bufferedChannelQueue.pool.nodeCount) + assert.GreaterOrEqual(t, bufferedChannelQueue.pool.nodeCount, 100) close(asyncTaskDone) }