diff --git a/src/OhMyThreads.jl b/src/OhMyThreads.jl index cf8da9a4..568ba957 100644 --- a/src/OhMyThreads.jl +++ b/src/OhMyThreads.jl @@ -20,7 +20,7 @@ using .Schedulers: Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler SerialScheduler include("implementation.jl") -export @tasks, @set, @local +export @tasks, @set, @local, @barrier export treduce, tmapreduce, treducemap, tmap, tmap!, tforeach, tcollect export Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler, SerialScheduler diff --git a/src/macro_impl.jl b/src/macro_impl.jl index b4e74a58..d2482a2e 100644 --- a/src/macro_impl.jl +++ b/src/macro_impl.jl @@ -1,3 +1,5 @@ +using OhMyThreads.Tools: SimpleBarrier + function tasks_macro(forex) if forex.head != :for throw(ErrorException("Expected a for loop after `@tasks`.")) @@ -20,6 +22,7 @@ function tasks_macro(forex) locals_before, locals_names = _maybe_handle_atlocal_block!(forbody.args) tls_names = isnothing(locals_before) ? [] : map(x -> x.args[1], locals_before) _maybe_handle_atset_block!(settings, forbody.args) + setup_barrier = _maybe_handle_atbarrier!(forbody.args) forbody = esc(forbody) itrng = esc(itrng) @@ -39,6 +42,7 @@ function tasks_macro(forex) end q = if isgiven(settings.reducer) quote + $setup_barrier $make_mapping_function tmapreduce(mapping_function, $(settings.reducer), $(itrng)) @@ -46,12 +50,14 @@ function tasks_macro(forex) elseif isgiven(settings.collect) maybe_warn_useless_init(settings) quote + $setup_barrier $make_mapping_function tmap(mapping_function, $(itrng)) end else maybe_warn_useless_init(settings) quote + $setup_barrier $make_mapping_function tforeach(mapping_function, $(itrng)) end @@ -68,7 +74,7 @@ function tasks_macro(forex) for (k, v) in settings.kwargs push!(kwexpr.args, Expr(:kw, k, v)) end - insert!(q.args[4].args, 2, kwexpr) + insert!(q.args[6].args, 2, kwexpr) # wrap everything in a let ... end block # and, potentially, define the `TaskLocalValue`s. @@ -151,16 +157,15 @@ function _atlocal_assign_to_exprs(ex) tls_type = esc(left_ex.args[2]) local_before = :($(tl_storage) = TaskLocalValue{$tls_type}(() -> $(tls_def))) else - tls_sym = esc(left_ex) + tls_sym = esc(left_ex) local_before = :($(tl_storage) = let f = () -> $(tls_def) - TaskLocalValue{Core.Compiler.return_type(f, Tuple{})}(f) - end) + TaskLocalValue{Core.Compiler.return_type(f, Tuple{})}(f) + end) end local_name = :($(tls_sym)) return local_before, local_name end - function _maybe_handle_atset_block!(settings, args) idcs = findall(args) do arg arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@set") @@ -201,3 +206,19 @@ function _handle_atset_single_assign!(settings, ex) push!(settings.kwargs, sym => esc(def)) end end + +function _maybe_handle_atbarrier!(args) + idcs = findall(args) do arg + arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@barrier") + end + isnothing(idcs) && return # no @barrier + setup_barrier = quote end + for i in idcs + @gensym barrier + # TODO: Problem... we need to know the number of tasks but I think we can't know that... + init_barrier_ex = esc(:($(barrier) = $(SimpleBarrier(10)))) # drop escape once merged with PR#93 + push!(setup_barrier.args, init_barrier_ex) + args[i] = :(wait($(barrier))) + end + return setup_barrier +end diff --git a/src/macros.jl b/src/macros.jl index 2c868144..daf8319c 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -151,3 +151,31 @@ end error("The @local macro may only be used inside of a @tasks block.") end end + +""" + @barrier + +This can be used inside a `@tasks for ... end` block to mark a barrier at which all tasks +will wait for each other (synchronize) before moving on. + +## Example + +```julia +using OhMyThreads: @tasks + +tstart = time_ns() * 1e-9 +now = () -> time_ns() * 1e-9 - tstart + +@tasks for i in 1:10 + @set ntasks = 10 + + sleep(i * 0.2) + println(i, ": arriving at barrier (", now(), ")") + @barrier + println(i, ": moving on (", now(), ")") +end +``` +""" +macro barrier() + error("The @barrier macro may only be used inside of a @tasks block.") +end diff --git a/src/tools.jl b/src/tools.jl index 0e8983e0..c04d2716 100644 --- a/src/tools.jl +++ b/src/tools.jl @@ -24,4 +24,54 @@ Return a `UInt` identifier for the current running [Task](https://docs.julialang """ taskid() = objectid(current_task()) +""" +SimpleBarrier(n::Integer) + +Simple reusable barrier for `n` parallel tasks. + +Given `b = SimpleBarrier(n)` and `n` parallel tasks, each task that calls +`wait(b)` will block until the other `n-1` tasks have called `wait(b)` as well. + +## Example +``` +using OhMyThreads.Tools: SimpleBarrier, @tasks + +n = nthreads() +barrier = SimpleBarrier(n) +@tasks for i in 1:n + @set ntasks = n + + println("A") + wait(barrier) # synchronize all tasks + println("B") + wait(barrier) # synchronize all tasks (reusable) + println("C") +end +``` +""" +struct SimpleBarrier + n::Int64 + c::Threads.Condition + cnt::Base.RefValue{Int64} + + function SimpleBarrier(n::Integer) + new(n, Threads.Condition(), Base.RefValue{Int64}(0)) + end +end + +function Base.wait(b::SimpleBarrier) + lock(b.c) + try + b.cnt[] += 1 + if b.cnt[] == b.n + b.cnt[] = 0 + notify(b.c) + else + wait(b.c) + end + finally + unlock(b.c) + end +end + end # Tools