diff --git a/src/tools.jl b/src/tools.jl index b0d95fd8..8029afe0 100644 --- a/src/tools.jl +++ b/src/tools.jl @@ -31,8 +31,8 @@ May be used to mark a region in parallel code to be executed by a single task on See [`try_enter!`](@ref) and [`reset!`](@ref). """ mutable struct OnlyOneRegion - @atomic latch::Bool - OnlyOneRegion() = new(false) + @atomic task::Union{Task, Nothing} + OnlyOneRegion() = new(nothing) end """ @@ -62,15 +62,14 @@ end ``` """ function try_enter!(f, s::OnlyOneRegion) - latch = @atomic :monotonic s.latch - if latch + ct = current_task() + t = @atomic :monotonic s.task + if !isnothing(t) && ct != t return end - (_, success) = @atomicreplace s.latch false=>true - if !success - return + if ct == t || (@atomicreplace s.task nothing=>ct).success + f() end - f() return end @@ -78,8 +77,8 @@ end Reset the `OnlyOneRegion` (so that it can be used again). """ function reset!(s::OnlyOneRegion) - @atomicreplace s.latch true=>false - nothing + @atomic s.task = nothing + return end """ diff --git a/test/runtests.jl b/test/runtests.jl index c436ab05..5681665f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -454,6 +454,22 @@ end catch ErrorException @test false end + + x = 0 + y = 0 + try + @tasks for i in 1:10 + @set ntasks = 2 + + y += 1 # not safe (race condition) + @only_one begin + x += 1 # parallel-safe because only a single task will execute this + end + end + @test x == 5 # a single task should have incremented x 5 times + catch ErrorException + @test false + end end @testset "@only_one + @one_by_one" begin