diff --git a/src/Microsoft.VisualStudio.Threading.Shared/JoinableTask+JoinableTaskSynchronizationContext.cs b/src/Microsoft.VisualStudio.Threading.Shared/JoinableTask+JoinableTaskSynchronizationContext.cs index 2295fef8a..25df9658c 100644 --- a/src/Microsoft.VisualStudio.Threading.Shared/JoinableTask+JoinableTaskSynchronizationContext.cs +++ b/src/Microsoft.VisualStudio.Threading.Shared/JoinableTask+JoinableTaskSynchronizationContext.cs @@ -75,9 +75,10 @@ internal bool MainThreadAffinitized /// public override void Post(SendOrPostCallback d, object state) { - if (this.job != null) + JoinableTask job = this.job; // capture as local in case field becomes null later. + if (job != null) { - this.job.Post(d, state, this.mainThreadAffinitized); + job.Post(d, state, this.mainThreadAffinitized); } else { diff --git a/src/Microsoft.VisualStudio.Threading.Shared/JoinableTask.cs b/src/Microsoft.VisualStudio.Threading.Shared/JoinableTask.cs index 3cda19bdc..391cb2d20 100644 --- a/src/Microsoft.VisualStudio.Threading.Shared/JoinableTask.cs +++ b/src/Microsoft.VisualStudio.Threading.Shared/JoinableTask.cs @@ -465,7 +465,8 @@ private bool SynchronouslyBlockingThreadPool get { return (this.state & JoinableTaskFlags.StartedSynchronously) == JoinableTaskFlags.StartedSynchronously - && (this.state & JoinableTaskFlags.StartedOnMainThread) == JoinableTaskFlags.None; + && (this.state & JoinableTaskFlags.StartedOnMainThread) != JoinableTaskFlags.StartedOnMainThread + && (this.state & JoinableTaskFlags.CompleteRequested) != JoinableTaskFlags.CompleteRequested; } } @@ -475,7 +476,8 @@ private bool SynchronouslyBlockingMainThread get { return (this.state & JoinableTaskFlags.StartedSynchronously) == JoinableTaskFlags.StartedSynchronously - && (this.state & JoinableTaskFlags.StartedOnMainThread) == JoinableTaskFlags.StartedOnMainThread; + && (this.state & JoinableTaskFlags.StartedOnMainThread) == JoinableTaskFlags.StartedOnMainThread + && (this.state & JoinableTaskFlags.CompleteRequested) != JoinableTaskFlags.CompleteRequested; } } diff --git a/src/Microsoft.VisualStudio.Threading.Tests.Shared/JoinableTaskTests.cs b/src/Microsoft.VisualStudio.Threading.Tests.Shared/JoinableTaskTests.cs index a4b82dd02..2fb6aa437 100644 --- a/src/Microsoft.VisualStudio.Threading.Tests.Shared/JoinableTaskTests.cs +++ b/src/Microsoft.VisualStudio.Threading.Tests.Shared/JoinableTaskTests.cs @@ -2651,6 +2651,51 @@ public void PostStress() this.PushFrame(); } + [StaFact] + public void StressFireAndForgetWorkFromCapturedSynchronizationContext() + { + for (int count = 0; count < 5000; count++) + { + var postDelegateInvoked = new ManualResetEventSlim(); + Task innerTask = null; + SynchronizationContext capturedContext = null; + bool posted = false; + + // Do the scheduling off the simulated STA thread so we can conveniently block later. + Task.Run(delegate + { + this.asyncPump.Run(delegate + { + capturedContext = SynchronizationContext.Current; + innerTask = Task.Run(async delegate + { + await Task.Yield(); + + capturedContext.Post( + s => + { + postDelegateInvoked.Set(); + }, + null); + posted = true; + }); + return TplExtensions.CompletedTask; + }); + }).Wait(); + + try + { + innerTask.GetAwaiter().GetResult(); + Assert.True(postDelegateInvoked.Wait(AsyncDelay), "Timed out waiting for posted delegate to execute. Posted: " + posted); + } + catch + { + this.Logger.WriteLine("iteration {0}", count); + throw; + } + } + } + /// /// Verifies that in the scenario when the initializing thread doesn't have a sync context at all (vcupgrade.exe) /// that reasonable behavior still occurs.