Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SimpleBarrier and @barrier #97

Merged
merged 11 commits into from
Mar 22, 2024
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Version 0.5.1
-------------
- ![Feature][badge-feature] Within a parallel `@tasks` block one can now mark a region with `@one_by_one`. This region will be run by one task at a time ("critical region").
- ![Feature][badge-feature] Within a `@tasks` block one can now mark a region as with `@only_one`. This region will be run by a single parallel task only (other tasks will skip over it).
- ![Experimental][badge-experimental] Added tentative support for `@barrier` in `@tasks` blocks. See `?OhMyThreads.Tools.@barrier` for more information. Note that this feature is experimental and **not** part of the public API (i.e. doesn't fall under SemVer).

Version 0.5.0
-------------
Expand Down Expand Up @@ -97,6 +98,7 @@ Version 0.2.0
[badge-breaking]: https://img.shields.io/badge/BREAKING-red.svg
[badge-deprecation]: https://img.shields.io/badge/Deprecation-orange.svg
[badge-feature]: https://img.shields.io/badge/Feature-green.svg
[badge-experimental]: https://img.shields.io/badge/Experimental-yellow.svg
[badge-enhancement]: https://img.shields.io/badge/Enhancement-blue.svg
[badge-bugfix]: https://img.shields.io/badge/Bugfix-purple.svg
[badge-fix]: https://img.shields.io/badge/Fix-purple.svg
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ makedocs(;
# ],
"API" => [
"Public API" => "refs/api.md",
"Experimental" => "refs/experimental.md",
"Internal" => "refs/internal.md"
]
],
Expand Down
16 changes: 16 additions & 0 deletions docs/src/refs/experimental.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
```@meta
CollapsedDocStrings = true
```

# Experimental

!!! warning
**Everything on this page is experimental and might changed or dropped at any point!**

## References

```@autodocs
Modules = [OhMyThreads, OhMyThreads.Experimental]
Public = false
Pages = ["OhMyThreads.jl", "experimental.jl"]
```
2 changes: 1 addition & 1 deletion docs/src/refs/internal.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ CollapsedDocStrings = true
# Internal

!!! warning
**Everything on this page is internal and might change at any point!**
**Everything on this page is internal and and might changed or dropped at any point!**

## References

Expand Down
1 change: 1 addition & 0 deletions src/OhMyThreads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ include("schedulers.jl")
using .Schedulers: Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler,
SerialScheduler
include("implementation.jl")
include("experimental.jl")

export @tasks, @set, @local, @one_by_one, @only_one
export treduce, tmapreduce, treducemap, tmap, tmap!, tforeach, tcollect
Expand Down
48 changes: 48 additions & 0 deletions src/experimental.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module Experimental

"""
@barrier

This can be used inside a `@tasks for ... end` to synchronize `n` parallel tasks.
Specifically, a task can only pass the `@barrier` if `n-1` other tasks have reached it
as well. The value of `n` is determined from `@set ntasks=...`, which
is required if one wants to use `@barrier`.

Because this feature is experimental, it is required to load `@barrier` explicitly, e.g. via
`using OhMyThreads.Experimental: @barrier`.

**WARNING:** It is the responsibility of the user to ensure that the right number of tasks
actually reach the barrier. Otherwise, a **deadlock** can occur. In partictular, if the
number of iterations is not a multiple of `n`, the last few iterations (remainder) will be
run by less than `n` tasks which will never be able to pass a `@barrier`.

## Example

```julia
using OhMyThreads: @tasks

# works
@tasks for i in 1:20
@set ntasks = 20

sleep(i * 0.2)
println(i, ": before")
@barrier
println(i, ": after")
end

# wrong - deadlock!
@tasks for i in 1:22 # ntasks % niterations != 0
@set ntasks = 20

println(i, ": before")
@barrier
println(i, ": after")
end
```
"""
macro barrier(args...)
error("The @barrier macro may only be used inside of a @tasks block.")
end

end # Experimental
44 changes: 38 additions & 6 deletions src/macro_impl.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using OhMyThreads.Tools: OnlyOneRegion, try_enter!
using OhMyThreads.Tools: SimpleBarrier
using OhMyThreads: OhMyThreads

function tasks_macro(forex)
function tasks_macro(forex; __module__)
if forex.head != :for
throw(ErrorException("Expected a for loop after `@tasks`."))
else
Expand Down Expand Up @@ -28,7 +30,8 @@ function tasks_macro(forex)
!(arg isa Expr && arg.head == :macrocall &&
arg.args[1] == Symbol("@only_one")) &&
!(arg isa Expr && arg.head == :macrocall &&
arg.args[1] == Symbol("@one_by_one"))
arg.args[1] == Symbol("@one_by_one")) &&
!(arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@barrier"))
end
forbody.args[i] = esc(forbody.args[i])
end
Expand All @@ -38,6 +41,14 @@ function tasks_macro(forex)
_maybe_handle_atset_block!(settings, forbody.args)
setup_onlyone_blocks = _maybe_handle_atonlyone_blocks!(forbody.args)
setup_onebyone_blocks = _maybe_handle_atonebyone_blocks!(forbody.args)
if isdefined(__module__, Symbol("@barrier"))
if __module__.var"@barrier" != OhMyThreads.Experimental.var"@barrier"
error("There seems to be a macro `@barrier` around which isn't `OhMyThreads.Experimental.@barrier`. This isn't supported.")
end
setup_barriers = _maybe_handle_atbarriers!(forbody.args, settings)
else
setup_barriers = nothing
end

itrng = esc(itrng)
itvar = esc(itvar)
Expand All @@ -58,6 +69,7 @@ function tasks_macro(forex)
quote
$setup_onlyone_blocks
$setup_onebyone_blocks
$setup_barriers
$make_mapping_function
tmapreduce(mapping_function, $(settings.reducer),
$(itrng))
Expand All @@ -67,6 +79,7 @@ function tasks_macro(forex)
quote
$setup_onlyone_blocks
$setup_onebyone_blocks
$setup_barriers
$make_mapping_function
tmap(mapping_function, $(itrng))
end
Expand All @@ -75,6 +88,7 @@ function tasks_macro(forex)
quote
$setup_onlyone_blocks
$setup_onebyone_blocks
$setup_barriers
$make_mapping_function
tforeach(mapping_function, $(itrng))
end
Expand All @@ -91,7 +105,7 @@ function tasks_macro(forex)
for (k, v) in settings.kwargs
push!(kwexpr.args, Expr(:kw, k, v))
end
insert!(q.args[8].args, 2, kwexpr)
insert!(q.args[10].args, 2, kwexpr)

# wrap everything in a let ... end block
# and, potentially, define the `TaskLocalValue`s.
Expand All @@ -113,12 +127,11 @@ function maybe_warn_useless_init(settings)
end

Base.@kwdef mutable struct Settings
# scheduler::Expr = :(DynamicScheduler())
scheduler::Union{Expr, QuoteNode, NotGiven} = NotGiven()
reducer::Union{Expr, Symbol, NotGiven} = NotGiven()
collect::Union{Bool, NotGiven} = NotGiven()
init::Union{Expr, Symbol, NotGiven} = NotGiven()
kwargs::Vector{Pair{Symbol, Any}} = Pair{Symbol, Any}[]
kwargs::Dict{Symbol, Any} = Dict{Symbol, Any}()
end

function _maybe_handle_atlocal_block!(args)
Expand Down Expand Up @@ -220,7 +233,8 @@ function _handle_atset_single_assign!(settings, ex)
def = def isa Bool ? def : esc(def)
setfield!(settings, sym, def)
else
push!(settings.kwargs, sym => esc(def))
# push!(settings.kwargs, sym => esc(def))
settings.kwargs[sym] = esc(def)
end
end

Expand Down Expand Up @@ -263,3 +277,21 @@ function _maybe_handle_atonebyone_blocks!(args)
end
return setup_onebyone_blocks
end

function _maybe_handle_atbarriers!(args, settings)
idcs = findall(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@barrier")
end
isnothing(idcs) && return # no @barrier found
setup_barriers = quote end
for i in idcs
!haskey(settings.kwargs, :ntasks) &&
throw(ErrorException("When using `@barrier`, the number of tasks must be " *
"specified explicitly, e.g. via `@set ntasks=...`. "))
ntasks = settings.kwargs[:ntasks]
@gensym barrier
push!(setup_barriers.args, :($(barrier) = $(SimpleBarrier)($ntasks)))
args[i] = :($(esc(:wait))($(barrier)))
end
return setup_barriers
end
2 changes: 1 addition & 1 deletion src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ end
```
"""
macro tasks(args...)
Implementation.tasks_macro(args...)
Implementation.tasks_macro(args...; __module__)
end

"""
Expand Down
48 changes: 48 additions & 0 deletions src/tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,52 @@ function reset!(s::OnlyOneRegion)
nothing
end

"""
SimpleBarrier(n::Integer)

Simple reusable barrier for `n` parallel tasks.

Given `b = SimpleBarrier(n)` and `n` parallel tasks, each task that calls
`wait(b)` will block until the other `n-1` tasks have called `wait(b)` as well.

## Example
```
n = nthreads()
barrier = SimpleBarrier(n)
@sync for i in 1:n
@spawn begin
println("A")
wait(barrier) # synchronize all tasks
println("B")
wait(barrier) # synchronize all tasks (reusable)
println("C")
end
end
```
"""
mutable struct SimpleBarrier
const n::Int64
const c::Threads.Condition
cnt::Int64

function SimpleBarrier(n::Integer)
new(n, Threads.Condition(), 0)
end
end

function Base.wait(b::SimpleBarrier)
lock(b.c)
try
b.cnt += 1
if b.cnt == b.n
b.cnt = 0
notify(b.c)
else
wait(b.c)
end
finally
unlock(b.c)
end
end

end # Tools
57 changes: 56 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test, OhMyThreads
using OhMyThreads: TaskLocalValue, WithTaskLocals, @fetch, promise_task_local
using OhMyThreads.Experimental: @barrier

include("Aqua.jl")

Expand Down Expand Up @@ -470,11 +471,65 @@ end
y += 1 # parallel-safe
end
end
@test x == 1 && y == 10
@test x == 1 && y == 10
catch ErrorException
@test false
end
end
end;

@testset "@barrier" begin
@test (@tasks for i in 1:20
@set ntasks = 20
@barrier
end) |> isnothing

@test try
@macroexpand @tasks for i in 1:20
@barrier
end
false
catch
true
end

@test try
x = Threads.Atomic{Int64}(0)
y = Threads.Atomic{Int64}(0)
@tasks for i in 1:20
@set ntasks = 20

Threads.atomic_add!(x, 1)
@barrier
if x[] < 20 && y[] > 0 # x hasn't reached 20 yet and y is already > 0
error("shouldn't happen")
end
Threads.atomic_add!(y, 1)
end
true
catch ErrorException
false
end

@test try
x = Threads.Atomic{Int64}(0)
y = Threads.Atomic{Int64}(0)
@tasks for i in 1:20
@set ntasks = 20

Threads.atomic_add!(x, 1)
@barrier
Threads.atomic_add!(x, 1)
@barrier
if x[] < 40 && y[] > 0 # x hasn't reached 20 yet and y is already > 0
error("shouldn't happen")
end
Threads.atomic_add!(y, 1)
end
true
catch ErrorException
false
end
end

# Todo way more testing, and easier tests to deal with
Loading