From 01deac57b3d9d9739893d2d5671119317788e6be Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Mon, 18 Mar 2024 19:33:56 +0100 Subject: [PATCH] section single; overhaul of forbody escaping --- CHANGELOG.md | 3 ++- src/macro_impl.jl | 31 +++++++++++++++++++++++++++---- src/macros.jl | 14 ++++++++++++++ src/tools.jl | 44 ++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 39 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 126 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f4b0a3d8..f1712ce1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,8 @@ OhMyThreads.jl Changelog Version 0.5.1 ------------- -- ![Feature][badge-feature] Within a `@tasks` block one can now mark a section as "critical" via `@section :critical begin ... end`. +- ![Feature][badge-feature] Within a `@tasks` block one can now mark a section as "critical" via `@section :critical begin ... end`. This section will be run by one task at a time. +- ![Feature][badge-feature] Within a `@tasks` block one can now mark a section as "single" via `@section :single begin ... end`. This section will be run by a single task only. Version 0.5.0 ------------- diff --git a/src/macro_impl.jl b/src/macro_impl.jl index 235e0f76..97d46c95 100644 --- a/src/macro_impl.jl +++ b/src/macro_impl.jl @@ -1,3 +1,5 @@ +using OhMyThreads.Tools: SectionSingle, try_enter + function tasks_macro(forex) if forex.head != :for throw(ErrorException("Expected a for loop after `@tasks`.")) @@ -17,12 +19,22 @@ function tasks_macro(forex) settings = Settings() + # Escape everything in the loop body that is not used in conjuction with one of our + # "macros", e.g. @set or @local. Code inside of these macro blocks will be escaped by + # the respective "macro" handling functions below. + for i in findall(forbody.args) do arg + !(arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@set")) && + !(arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@local")) && + !(arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@section")) + end + forbody.args[i] = esc(forbody.args[i]) + end + locals_before, locals_names = _maybe_handle_atlocal_block!(forbody.args) tls_names = isnothing(locals_before) ? [] : map(x -> x.args[1], locals_before) _maybe_handle_atset_block!(settings, forbody.args) setup_sections = _maybe_handle_atsection_blocks!(forbody.args) - forbody = esc(forbody) itrng = esc(itrng) itvar = esc(itvar) @@ -217,11 +229,22 @@ function _maybe_handle_atsection_blocks!(args) if kind isa QuoteNode if kind.value == :critical @gensym critical_lock - init_lock_ex = esc(:($(critical_lock) = $(ReentrantLock()))) + init_lock_ex = :($(critical_lock) = $(Base.ReentrantLock())) + # init_lock_ex = esc(:($(critical_lock) = $(Base.ReentrantLock()))) push!(setup_sections.args, init_lock_ex) args[i] = quote - lock($(critical_lock)) do - $(body) + $(esc(:lock))($(critical_lock)) do + $(esc(body)) + end + end + elseif kind.value == :single + @gensym single_section + # init_single_section_ex = esc(:($(single_section) = $(SectionSingle()))) + init_single_section_ex = :($(single_section) = $(SectionSingle())) + push!(setup_sections.args, init_single_section_ex) + args[i] = quote + Tools.try_enter($(single_section)) do + $(esc(body)) end end else diff --git a/src/macros.jl b/src/macros.jl index 7c331aa8..46252829 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -164,6 +164,7 @@ Multiple `@section` blocks are supported. ## Kinds * `:critical`: Section of code that must be executed by a single task at a time (arbitrary order). +* `:single`: Section of code that must be executed by a single task only. All other tasks will skip over this section. ## Example @@ -179,6 +180,19 @@ Multiple `@section` blocks are supported. println(i, ": after") end ``` + +```julia +@tasks for i in 1:10 + @set ntasks = 10 + + println(i, ": before") + @section :single begin + println(i, ": only printed by a single task") + sleep(1) + end + println(i, ": after") +end +``` """ macro section(args...) error("The @section macro may only be used inside of a @tasks block.") diff --git a/src/tools.jl b/src/tools.jl index 0e8983e0..f6a1d0d4 100644 --- a/src/tools.jl +++ b/src/tools.jl @@ -24,4 +24,48 @@ Return a `UInt` identifier for the current running [Task](https://docs.julialang """ taskid() = objectid(current_task()) +""" +When `try_enter(s::SectionSingle) do ... end` is called from multiple parallel tasks only +a single task will run the content of the `do ... end` block. +""" +struct SectionSingle + first::Base.RefValue{Bool} + lck::ReentrantLock + SectionSingle() = new(Ref(true), ReentrantLock()) +end + +""" + try_enter(f, s::SectionSingle) + +When called from multiple parallel tasks (on a shared `s::SectionSingle`) only a single +task will execute `f`. Typical usage: + +```julia +using OhMyThreads.Tools: SectionSingle + +s = SectionSingle() + +@tasks for i in 1:10 + @set ntasks = 10 + + println(i, ": before") + try_enter(s) do + println(i, ": only printed by a single task") + sleep(1) + end + println(i, ": after") +end +``` +""" +function try_enter(f, s::SectionSingle) + run_f = false + lock(s.lck) do + if s.first[] + run_f = true # The first task to try_enter → run f + s.first[] = false + end + end + run_f && f() +end + end # Tools diff --git a/test/runtests.jl b/test/runtests.jl index 4fe17be4..3e867b6b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -412,6 +412,45 @@ end else @test true end + + # test escaping + x = 0 + y = 0 + sao = SingleAccessOnly() + try + @tasks for i in 1:10 + @set ntasks = 10 + + y += 1 # not safe (race condition) + @section :critical begin + x += 1 # parallel-safe because inside of critical section + acquire(sao) do + sleep(0.01) + end + end + end + @test x == 10 + catch ErrorException + @test false + end + end + + @testset ":single" begin + x = 0 + y = 0 + try + @tasks for i in 1:10 + @set ntasks = 10 + + y += 1 # not safe (race condition) + @section :single begin + x += 1 # parallel-safe because inside of single section + end + end + @test x == 1 # only a single task should have incremented x + catch ErrorException + @test false + end end end;