diff --git a/src/Disruptor.Benchmarks/WaitStrategies/PingPongAsyncWaitStrategyBenchmarks.cs b/src/Disruptor.Benchmarks/WaitStrategies/PingPongAsyncWaitStrategyBenchmarks.cs index 2d0b72e0..4e0930e6 100644 --- a/src/Disruptor.Benchmarks/WaitStrategies/PingPongAsyncWaitStrategyBenchmarks.cs +++ b/src/Disruptor.Benchmarks/WaitStrategies/PingPongAsyncWaitStrategyBenchmarks.cs @@ -13,14 +13,14 @@ public class PingPongAsyncWaitStrategyBenchmarks : IDisposable private readonly AsyncWaitStrategy _pongWaitStrategy = new(); private readonly Sequence _pingCursor = new(); private readonly Sequence _pongCursor = new(); + private readonly AsyncWaitState _pingAsyncWaitState; + private readonly AsyncWaitState _pongAsyncWaitState; private readonly Task _pongTask; - private readonly DependentSequenceGroup _pingDependentSequences; - private readonly DependentSequenceGroup _pongDependentSequences; public PingPongAsyncWaitStrategyBenchmarks() { - _pingDependentSequences = new DependentSequenceGroup(_pingCursor); - _pongDependentSequences = new DependentSequenceGroup(_pongCursor); + _pingAsyncWaitState = new AsyncWaitState(new DependentSequenceGroup(_pingCursor), _cancellationTokenSource.Token); + _pongAsyncWaitState = new AsyncWaitState(new DependentSequenceGroup(_pongCursor), _cancellationTokenSource.Token); _pongTask = Task.Run(RunPong); } @@ -40,8 +40,10 @@ private async Task RunPong() { sequence++; - await _pingWaitStrategy.WaitForAsync(sequence, _pingDependentSequences, _cancellationTokenSource.Token).ConfigureAwait(false); + // Wait for ping + await _pingWaitStrategy.WaitForAsync(sequence, _pingAsyncWaitState).ConfigureAwait(false); + // Publish pong _pongCursor.SetValue(sequence); _pongWaitStrategy.SignalAllWhenBlocking(); } @@ -62,10 +64,12 @@ public async Task Run() for (var s = start; s < end; s++) { + // Publish ping _pingCursor.SetValue(s); _pingWaitStrategy.SignalAllWhenBlocking(); - await _pongWaitStrategy.WaitForAsync(s, _pongDependentSequences, _cancellationTokenSource.Token).ConfigureAwait(false); + // Wait for pong + await _pongWaitStrategy.WaitForAsync(s, _pongAsyncWaitState).ConfigureAwait(false); } } } diff --git a/src/Disruptor.Tests/AsyncWaitStrategyTests.cs b/src/Disruptor.Tests/AsyncWaitStrategyTests.cs index f6cf8933..ea54682c 100644 --- a/src/Disruptor.Tests/AsyncWaitStrategyTests.cs +++ b/src/Disruptor.Tests/AsyncWaitStrategyTests.cs @@ -19,12 +19,12 @@ public void ShouldWaitFromMultipleThreadsAsync() var waitTask1 = Task.Run(async () => { - waitResult1.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor), CancellationToken)); + waitResult1.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor), CancellationToken))); Thread.Sleep(1); sequence1.SetValue(10); }); - var waitTask2 = Task.Run(async () => waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, sequence1), CancellationToken))); + var waitTask2 = Task.Run(async () => waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, sequence1), CancellationToken)))); // Ensure waiting tasks are blocked AssertIsNotCompleted(waitResult1.Task); @@ -62,12 +62,12 @@ public void ShouldWaitFromMultipleThreadsSyncAndAsync() var waitTask2 = Task.Run(async () => { - waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, sequence1), CancellationToken)); + waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, sequence1), CancellationToken))); Thread.Sleep(1); sequence2.SetValue(10); }); - var waitTask3 = Task.Run(async () => waitResult3.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, sequence2), CancellationToken))); + var waitTask3 = Task.Run(async () => waitResult3.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, sequence2), CancellationToken)))); // Ensure waiting tasks are blocked AssertIsNotCompleted(waitResult1.Task); @@ -103,7 +103,7 @@ public void ShouldWaitAfterCancellationAsync() { try { - await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, dependentSequence), CancellationToken); + await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, dependentSequence), CancellationToken)); } catch (Exception e) { @@ -129,7 +129,7 @@ public void ShouldUnblockAfterCancellationAsync() { try { - await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, dependentSequence), CancellationToken); + await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, dependentSequence), CancellationToken)); } catch (Exception e) { @@ -165,7 +165,7 @@ public void ShouldWaitMultipleTimesAsync() for (var i = 0; i < 500; i++) { - await waitStrategy.WaitForAsync(i, dependentSequences, cancellationTokenSource.Token).ConfigureAwait(false); + await waitStrategy.WaitForAsync(i, new AsyncWaitState(dependentSequences, cancellationTokenSource.Token)).ConfigureAwait(false); sequence1.SetValue(i); } }); @@ -177,7 +177,7 @@ public void ShouldWaitMultipleTimesAsync() for (var i = 0; i < 500; i++) { - await waitStrategy.WaitForAsync(i, dependentSequences, cancellationTokenSource.Token).ConfigureAwait(false); + await waitStrategy.WaitForAsync(i, new AsyncWaitState(dependentSequences, cancellationTokenSource.Token)).ConfigureAwait(false); } }); diff --git a/src/Disruptor.Tests/AsyncWaitStrategyTestsWithTimeout.cs b/src/Disruptor.Tests/AsyncWaitStrategyTestsWithTimeout.cs index fafdc221..56618845 100644 --- a/src/Disruptor.Tests/AsyncWaitStrategyTestsWithTimeout.cs +++ b/src/Disruptor.Tests/AsyncWaitStrategyTestsWithTimeout.cs @@ -66,12 +66,12 @@ public void ShouldWaitFromMultipleThreadsWithTimeoutsAsync() var waitTask1 = Task.Run(async () => { - waitResult1.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor), CancellationToken)); + waitResult1.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor), CancellationToken))); Thread.Sleep(1); sequence1.SetValue(10); }); - var waitTask2 = Task.Run(async () => waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, sequence1), CancellationToken))); + var waitTask2 = Task.Run(async () => waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, sequence1), CancellationToken)))); // Ensure waiting tasks are blocked AssertIsNotCompleted(waitResult1.Task); diff --git a/src/Disruptor.Tests/DisruptorStressTest.cs b/src/Disruptor.Tests/DisruptorStressTest.cs index 8a0b5efc..ef769a80 100644 --- a/src/Disruptor.Tests/DisruptorStressTest.cs +++ b/src/Disruptor.Tests/DisruptorStressTest.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Disruptor.Dsl; @@ -27,7 +28,7 @@ public void ShouldHandleLotsOfThreads_AsyncBatchEventHandler() ShouldHandleLotsOfThreads(new AsyncWaitStrategy(), 2_000_000); } - private static void ShouldHandleLotsOfThreads(IWaitStrategy waitStrategy, int iterations) where T : IHandler, new() + private static void ShouldHandleLotsOfThreads(IWaitStrategy waitStrategy, int iterations) where T : ITestHandler, new() { var disruptor = new Disruptor(TestEvent.Factory, 65_536, TaskScheduler.Current, ProducerType.Multi, waitStrategy); var ringBuffer = disruptor.RingBuffer; @@ -36,7 +37,6 @@ public void ShouldHandleLotsOfThreads_AsyncBatchEventHandler() var publisherCount = Math.Clamp(Environment.ProcessorCount / 2, 1, 8); var handlerCount = Math.Clamp(Environment.ProcessorCount / 2, 1, 8); - var end = new CountdownEvent(publisherCount); var start = new CountdownEvent(publisherCount); var handlers = new T[handlerCount]; @@ -50,26 +50,15 @@ public void ShouldHandleLotsOfThreads_AsyncBatchEventHandler() var publishers = new Publisher[publisherCount]; for (var i = 0; i < publishers.Length; i++) { - publishers[i] = new Publisher(ringBuffer, iterations, start, end); + publishers[i] = new Publisher(ringBuffer, iterations, start); } disruptor.Start(); - foreach (var publisher in publishers) - { - Task.Run(publisher.Run); - } - - end.Wait(); - - var spinWait = new SpinWait(); + var publisherTasks = publishers.Select(x => Task.Run(x.Run)).ToArray(); + Task.WaitAll(publisherTasks); - while (ringBuffer.Cursor < (iterations - 1)) - { - spinWait.SpinOnce(); - } - - disruptor.Shutdown(); + disruptor.Shutdown(TimeSpan.FromSeconds(10)); foreach (var publisher in publishers) { @@ -78,12 +67,12 @@ public void ShouldHandleLotsOfThreads_AsyncBatchEventHandler() foreach (var handler in handlers) { - Assert.That(handler.MessagesSeen, Is.Not.EqualTo(0)); + Assert.That(handler.MessagesSeen, Is.EqualTo(iterations * publishers.Length)); Assert.That(handler.FailureCount, Is.EqualTo(0)); } } - private interface IHandler + private interface ITestHandler { int FailureCount { get; } int MessagesSeen { get; } @@ -91,7 +80,7 @@ private interface IHandler void Register(Disruptor disruptor); } - private class TestEventHandler : IEventHandler, IHandler + private class TestEventHandler : IEventHandler, ITestHandler { public int FailureCount { get; private set; } public int MessagesSeen { get; private set; } @@ -112,7 +101,7 @@ public void OnEvent(TestEvent @event, long sequence, bool endOfBatch) } } - private class TestBatchEventHandler : IBatchEventHandler, IHandler + private class TestBatchEventHandler : IBatchEventHandler, ITestHandler { public int FailureCount { get; private set; } public int MessagesSeen { get; private set; } @@ -139,7 +128,7 @@ public void OnBatch(EventBatch batch, long sequence) } } - private class TestAsyncBatchEventHandler : IAsyncBatchEventHandler, IHandler + private class TestAsyncBatchEventHandler : IAsyncBatchEventHandler, ITestHandler { public int FailureCount { get; private set; } public int MessagesSeen { get; private set; } @@ -171,16 +160,14 @@ public async ValueTask OnBatch(EventBatch batch, long sequence) private class Publisher { private readonly RingBuffer _ringBuffer; - private readonly CountdownEvent _end; private readonly CountdownEvent _start; private readonly int _iterations; public bool Failed; - public Publisher(RingBuffer ringBuffer, int iterations, CountdownEvent start, CountdownEvent end) + public Publisher(RingBuffer ringBuffer, int iterations, CountdownEvent start) { _ringBuffer = ringBuffer; - _end = end; _start = start; _iterations = iterations; } @@ -195,22 +182,18 @@ public void Run() var i = _iterations; while (--i != -1) { - var next = _ringBuffer.Next(); - var testEvent = _ringBuffer[next]; - testEvent.Sequence = next; - testEvent.A = next + 13; - testEvent.B = next - 7; - _ringBuffer.Publish(next); + var sequence = _ringBuffer.Next(); + var testEvent = _ringBuffer[sequence]; + testEvent.Sequence = sequence; + testEvent.A = sequence + 13; + testEvent.B = sequence - 7; + _ringBuffer.Publish(sequence); } } catch (Exception) { Failed = true; } - finally - { - _end.Signal(); - } } } diff --git a/src/Disruptor.Tests/Dsl/DisruptorTests.cs b/src/Disruptor.Tests/Dsl/DisruptorTests.cs index 2ce9c457..f0d27152 100644 --- a/src/Disruptor.Tests/Dsl/DisruptorTests.cs +++ b/src/Disruptor.Tests/Dsl/DisruptorTests.cs @@ -104,8 +104,8 @@ public void ShouldPublishAndHandleEvent_AsyncBatchEventHandler() var eventCounter = new CountdownEvent(2); var values = new List(); - _disruptor.HandleEventsWith(new TestBatchEventHandler(e => values.Add(e.Value))) - .Then(new TestBatchEventHandler(e => eventCounter.Signal())); + _disruptor.HandleEventsWith(new TestAsyncBatchEventHandler(e => values.Add(e.Value))) + .Then(new TestAsyncBatchEventHandler(e => eventCounter.Signal())); _disruptor.Start(); diff --git a/src/Disruptor/AsyncEventStream.cs b/src/Disruptor/AsyncEventStream.cs index 02fdf66f..b49d0b7d 100644 --- a/src/Disruptor/AsyncEventStream.cs +++ b/src/Disruptor/AsyncEventStream.cs @@ -119,12 +119,14 @@ private class Enumerator : IAsyncEnumerator> private readonly Sequence _sequence; private readonly CancellationTokenRegistration _cancellationTokenRegistration; private readonly CancellationTokenSource _linkedTokenSource; + private readonly AsyncWaitState _asyncWaitState; public Enumerator(AsyncEventStream asyncEventStream, Sequence sequence, CancellationToken streamCancellationToken, CancellationToken enumeratorCancellationToken) { _asyncEventStream = asyncEventStream; _sequence = sequence; _linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(streamCancellationToken, enumeratorCancellationToken); + _asyncWaitState = new AsyncWaitState(asyncEventStream._dependentSequences, _linkedTokenSource.Token, asyncEventStream._sequencer); _cancellationTokenRegistration = _linkedTokenSource.Token.Register(x => ((IAsyncWaitStrategy)x!).SignalAllWhenBlocking(), asyncEventStream._waitStrategy); } @@ -151,16 +153,12 @@ public async ValueTask MoveNextAsync() _linkedTokenSource.Token.ThrowIfCancellationRequested(); - var waitResult = await _asyncEventStream._waitStrategy.WaitForAsync(nextSequence, _asyncEventStream._dependentSequences, _linkedTokenSource.Token).ConfigureAwait(false); + var waitResult = await _asyncEventStream._waitStrategy.WaitForAsync(nextSequence, _asyncWaitState).ConfigureAwait(false); if (waitResult.UnsafeAvailableSequence < nextSequence) continue; - var availableSequence = _asyncEventStream._sequencer.GetHighestPublishedSequence(nextSequence, waitResult.UnsafeAvailableSequence); - if (availableSequence >= nextSequence) - { - Current = _asyncEventStream._dataProvider.GetBatch(nextSequence, availableSequence); - return true; - } + Current = _asyncEventStream._dataProvider.GetBatch(nextSequence, waitResult.UnsafeAvailableSequence); + return true; } } } diff --git a/src/Disruptor/AsyncSequenceBarrier.cs b/src/Disruptor/AsyncSequenceBarrier.cs index 21ce5cdc..45a3aff7 100644 --- a/src/Disruptor/AsyncSequenceBarrier.cs +++ b/src/Disruptor/AsyncSequenceBarrier.cs @@ -13,6 +13,7 @@ public sealed class AsyncSequenceBarrier private readonly IAsyncWaitStrategy _waitStrategy; private readonly DependentSequenceGroup _dependentSequences; private CancellationTokenSource _cancellationTokenSource; + private AsyncWaitState _asyncWaitState; public AsyncSequenceBarrier(ISequencer sequencer, IWaitStrategy waitStrategy, DependentSequenceGroup dependentSequences) { @@ -23,6 +24,7 @@ public AsyncSequenceBarrier(ISequencer sequencer, IWaitStrategy waitStrategy, De _waitStrategy = asyncWaitStrategy; _dependentSequences = dependentSequences; _cancellationTokenSource = new CancellationTokenSource(); + _asyncWaitState = new AsyncWaitState(dependentSequences, _cancellationTokenSource.Token, _sequencer); } public DependentSequenceGroup DependentSequences => _dependentSequences; @@ -65,26 +67,13 @@ public ValueTask WaitForAsync(long return new ValueTask(_sequencer.GetHighestPublishedSequence(sequence, availableSequence)); } - if (typeof(TSequenceBarrierOptions) == typeof(ISequenceBarrierOptions.IsDependentSequencePublished)) - { - return InvokeWaitStrategy(sequence); - } - - return InvokeWaitStrategyAndWaitForPublishedSequence(sequence); + return InvokeWaitStrategy(sequence); } [MethodImpl(MethodImplOptions.NoInlining)] private ValueTask InvokeWaitStrategy(long sequence) { - return _waitStrategy.WaitForAsync(sequence, _dependentSequences, _cancellationTokenSource.Token); - } - - [MethodImpl(MethodImplOptions.NoInlining)] - private async ValueTask InvokeWaitStrategyAndWaitForPublishedSequence(long sequence) - { - var waitResult = await _waitStrategy.WaitForAsync(sequence, _dependentSequences, _cancellationTokenSource.Token).ConfigureAwait(false); - - return waitResult.UnsafeAvailableSequence >= sequence ? _sequencer.GetHighestPublishedSequence(sequence, waitResult.UnsafeAvailableSequence) : waitResult; + return _waitStrategy.WaitForAsync(sequence, _asyncWaitState); } public void ResetProcessing() @@ -93,6 +82,7 @@ public void ResetProcessing() // has no finalizer and no unmanaged resources to release. _cancellationTokenSource = new CancellationTokenSource(); + _asyncWaitState = new AsyncWaitState(_dependentSequences, _cancellationTokenSource.Token, _sequencer); } public void CancelProcessing() diff --git a/src/Disruptor/AsyncWaitState.cs b/src/Disruptor/AsyncWaitState.cs new file mode 100644 index 00000000..28b6feb9 --- /dev/null +++ b/src/Disruptor/AsyncWaitState.cs @@ -0,0 +1,117 @@ +using System; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Sources; + +namespace Disruptor; + +/// +/// State . +/// Used to store per-caller synchronization primitives. +/// +public class AsyncWaitState +{ + private readonly ValueTaskSource _valueTaskSource; + private readonly CancellationToken _cancellationToken; + private readonly DependentSequenceGroup _dependentSequences; + private readonly ISequencer? _sequencer; + private ManualResetValueTaskSourceCore _valueTaskSourceCore; + private long _sequence; + + public AsyncWaitState(DependentSequenceGroup dependentSequences, CancellationToken cancellationToken, ISequencer? sequencer = null) + { + _valueTaskSource = new(this); + _cancellationToken = cancellationToken; + _dependentSequences = dependentSequences; + _sequencer = sequencer != null && IsSequencerRequired(sequencer, dependentSequences) ? sequencer : null; + _valueTaskSourceCore = new() { RunContinuationsAsynchronously = true }; + } + + public long CursorValue => _dependentSequences.CursorValue; + + public CancellationToken CancellationToken => _cancellationToken; + + public void ThrowIfCancellationRequested() + { + CancellationToken.ThrowIfCancellationRequested(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public long AggressiveSpinWaitFor(long sequence) + { + return _dependentSequences.AggressiveSpinWaitFor(sequence, _cancellationToken); + } + + public void Signal() + { + _valueTaskSourceCore.SetResult(true); + } + + public ValueTask Wait(long sequence) + { + _valueTaskSourceCore.Reset(); + _sequence = sequence; + + return new ValueTask(_valueTaskSource, _valueTaskSourceCore.Version); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public SequenceWaitResult GetAvailableSequence(long sequence) + { + var result = AggressiveSpinWaitFor(sequence); + + if (_sequencer != null && result >= _sequence) + return _sequencer.GetHighestPublishedSequence(_sequence, result); + + return result; + } + + private SequenceWaitResult GetResult(short token) + { + _valueTaskSourceCore.GetResult(token); + + return GetAvailableSequence(_sequence); + } + + private ValueTaskSourceStatus GetStatus(short token) + { + return _valueTaskSourceCore.GetStatus(token); + } + + private void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) + { + _valueTaskSourceCore.OnCompleted(continuation, state, token, flags); + } + + private static bool IsSequencerRequired(ISequencer sequencer, DependentSequenceGroup dependentSequences) + { + var isDependentSequencePublished = ISequenceBarrierOptions.Get(sequencer, dependentSequences) is ISequenceBarrierOptions.IsDependentSequencePublished; + return !isDependentSequencePublished; + } + + private class ValueTaskSource : IValueTaskSource + { + private readonly AsyncWaitState _asyncWaitState; + + public ValueTaskSource(AsyncWaitState asyncWaitState) + { + _asyncWaitState = asyncWaitState; + } + + public SequenceWaitResult GetResult(short token) + { + return _asyncWaitState.GetResult(token); + } + + public ValueTaskSourceStatus GetStatus(short token) + { + return _asyncWaitState.GetStatus(token); + } + + public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) + { + _asyncWaitState.OnCompleted(continuation, state, token, flags); + } + } +} diff --git a/src/Disruptor/AsyncWaitStrategy.cs b/src/Disruptor/AsyncWaitStrategy.cs index 4dc13423..ce6d86d3 100644 --- a/src/Disruptor/AsyncWaitStrategy.cs +++ b/src/Disruptor/AsyncWaitStrategy.cs @@ -13,7 +13,7 @@ namespace Disruptor; /// public sealed class AsyncWaitStrategy : IAsyncWaitStrategy { - private readonly List> _taskCompletionSources = new(); + private readonly List _asyncWaitStates = new(); private readonly object _gate = new(); private bool _hasSyncWaiter; @@ -46,44 +46,33 @@ public void SignalAllWhenBlocking() Monitor.PulseAll(_gate); } - foreach (var completionSource in _taskCompletionSources) + foreach (var completionSource in _asyncWaitStates) { - completionSource.TrySetResult(true); + completionSource.Signal(); } - _taskCompletionSources.Clear(); + _asyncWaitStates.Clear(); } } - public async ValueTask WaitForAsync(long sequence, DependentSequenceGroup dependentSequences, CancellationToken cancellationToken) + public ValueTask WaitForAsync(long sequence, AsyncWaitState asyncWaitState) { - while (dependentSequences.CursorValue < sequence) + if (asyncWaitState.CursorValue < sequence) { - await WaitForAsyncImpl(sequence, dependentSequences, cancellationToken).ConfigureAwait(false); - } - - return dependentSequences.AggressiveSpinWaitFor(sequence, cancellationToken); - } - - private async ValueTask WaitForAsyncImpl(long sequence, DependentSequenceGroup dependentSequences, CancellationToken cancellationToken) - { - TaskCompletionSource tcs; - - lock (_gate) - { - if (dependentSequences.CursorValue >= sequence) + lock (_gate) { - return; - } + if (asyncWaitState.CursorValue < sequence) + { + asyncWaitState.ThrowIfCancellationRequested(); - cancellationToken.ThrowIfCancellationRequested(); + _asyncWaitStates.Add(asyncWaitState); - tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _taskCompletionSources.Add(tcs); + return asyncWaitState.Wait(sequence); + } + } } - // Using cancellationToken in the await is not required because SignalAllWhenBlocking is always invoked by - // the sequencer barrier after cancellation. + var availableSequence = asyncWaitState.GetAvailableSequence(sequence); - await tcs.Task.ConfigureAwait(false); + return new ValueTask(availableSequence); } } diff --git a/src/Disruptor/IAsyncWaitStrategy.cs b/src/Disruptor/IAsyncWaitStrategy.cs index 6ff7bf23..49f998c6 100644 --- a/src/Disruptor/IAsyncWaitStrategy.cs +++ b/src/Disruptor/IAsyncWaitStrategy.cs @@ -12,9 +12,8 @@ public interface IAsyncWaitStrategy : IWaitStrategy /// Wait for the given sequence to be available. /// /// sequence to be waited on - /// sequences on which to wait - /// processing cancellation token + /// TODO /// either the sequence that is available (which may be greater than the requested sequence), or a timeout /// - ValueTask WaitForAsync(long sequence, DependentSequenceGroup dependentSequences, CancellationToken cancellationToken); + ValueTask WaitForAsync(long sequence, AsyncWaitState asyncWaitState); } diff --git a/src/Disruptor/TimeoutAsyncWaitStrategy.cs b/src/Disruptor/TimeoutAsyncWaitStrategy.cs index 252707e9..25ab69a3 100644 --- a/src/Disruptor/TimeoutAsyncWaitStrategy.cs +++ b/src/Disruptor/TimeoutAsyncWaitStrategy.cs @@ -75,43 +75,53 @@ public void SignalAllWhenBlocking() } } - public async ValueTask WaitForAsync(long sequence, DependentSequenceGroup dependentSequences, CancellationToken cancellationToken) + public async ValueTask WaitForAsync(long sequence, AsyncWaitState asyncWaitState) { - while (dependentSequences.CursorValue < sequence) + while (asyncWaitState.CursorValue < sequence) { - var waitSucceeded = await WaitForAsyncImpl(sequence, dependentSequences, cancellationToken).ConfigureAwait(false); + var waitSucceeded = await WaitForAsyncImpl(sequence, asyncWaitState).ConfigureAwait(false); if (!waitSucceeded) { return SequenceWaitResult.Timeout; } } - return dependentSequences.AggressiveSpinWaitFor(sequence, cancellationToken); + return asyncWaitState.AggressiveSpinWaitFor(sequence); } - private async ValueTask WaitForAsyncImpl(long sequence, DependentSequenceGroup dependentSequences, CancellationToken cancellationToken) + private async ValueTask WaitForAsyncImpl(long sequence, AsyncWaitState asyncWaitState) { TaskCompletionSource tcs; lock (_gate) { - if (dependentSequences.CursorValue >= sequence) + if (asyncWaitState.CursorValue >= sequence) { return true; } - cancellationToken.ThrowIfCancellationRequested(); + asyncWaitState.ThrowIfCancellationRequested(); tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); _taskCompletionSources.Add(tcs); } - // Using cancellationToken in the await is not required because SignalAllWhenBlocking is always invoked by - // the sequencer barrier after cancellation. + using (var cts = new CancellationTokenSource()) + { + var delayTask = Task.Delay(_timeoutMilliseconds, cts.Token); - // ReSharper disable once MethodSupportsCancellation - await Task.WhenAny(tcs.Task, Task.Delay(_timeoutMilliseconds)).ConfigureAwait(false); + var resultTask = await Task.WhenAny(tcs.Task, delayTask).ConfigureAwait(false); + if (resultTask == delayTask) + { + return false; + } - return tcs.Task.IsCompleted; + // Cancel the timer task so that it does not fire + cts.Cancel(); + + // tcs.Task is not awaited because it cannot possibly throw + + return true; + } } }