Skip to content

Commit

Permalink
attempt (with a big problem, not fully working)
Browse files Browse the repository at this point in the history
  • Loading branch information
carstenbauer committed Mar 18, 2024
1 parent 3513109 commit cd08e66
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/OhMyThreads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using .Schedulers: Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler
SerialScheduler
include("implementation.jl")

export @tasks, @set, @local
export @tasks, @set, @local, @barrier
export treduce, tmapreduce, treducemap, tmap, tmap!, tforeach, tcollect
export Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler, SerialScheduler

Expand Down
31 changes: 26 additions & 5 deletions src/macro_impl.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using OhMyThreads.Tools: SimpleBarrier

function tasks_macro(forex)
if forex.head != :for
throw(ErrorException("Expected a for loop after `@tasks`."))
Expand All @@ -20,6 +22,7 @@ function tasks_macro(forex)
locals_before, locals_names = _maybe_handle_atlocal_block!(forbody.args)
tls_names = isnothing(locals_before) ? [] : map(x -> x.args[1], locals_before)
_maybe_handle_atset_block!(settings, forbody.args)
setup_barrier = _maybe_handle_atbarrier!(forbody.args)

forbody = esc(forbody)
itrng = esc(itrng)
Expand All @@ -39,19 +42,22 @@ function tasks_macro(forex)
end
q = if isgiven(settings.reducer)
quote
$setup_barrier
$make_mapping_function
tmapreduce(mapping_function, $(settings.reducer),
$(itrng))
end
elseif isgiven(settings.collect)
maybe_warn_useless_init(settings)
quote
$setup_barrier
$make_mapping_function
tmap(mapping_function, $(itrng))
end
else
maybe_warn_useless_init(settings)
quote
$setup_barrier
$make_mapping_function
tforeach(mapping_function, $(itrng))
end
Expand All @@ -68,7 +74,7 @@ function tasks_macro(forex)
for (k, v) in settings.kwargs
push!(kwexpr.args, Expr(:kw, k, v))
end
insert!(q.args[4].args, 2, kwexpr)
insert!(q.args[6].args, 2, kwexpr)

# wrap everything in a let ... end block
# and, potentially, define the `TaskLocalValue`s.
Expand Down Expand Up @@ -151,16 +157,15 @@ function _atlocal_assign_to_exprs(ex)
tls_type = esc(left_ex.args[2])
local_before = :($(tl_storage) = TaskLocalValue{$tls_type}(() -> $(tls_def)))
else
tls_sym = esc(left_ex)
tls_sym = esc(left_ex)
local_before = :($(tl_storage) = let f = () -> $(tls_def)
TaskLocalValue{Core.Compiler.return_type(f, Tuple{})}(f)
end)
TaskLocalValue{Core.Compiler.return_type(f, Tuple{})}(f)
end)
end
local_name = :($(tls_sym))
return local_before, local_name
end


function _maybe_handle_atset_block!(settings, args)
idcs = findall(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@set")
Expand Down Expand Up @@ -201,3 +206,19 @@ function _handle_atset_single_assign!(settings, ex)
push!(settings.kwargs, sym => esc(def))
end
end

function _maybe_handle_atbarrier!(args)
idcs = findall(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@barrier")
end
isnothing(idcs) && return # no @barrier
setup_barrier = quote end
for i in idcs
@gensym barrier
# TODO: Problem... we need to know the number of tasks but I think we can't know that...
init_barrier_ex = esc(:($(barrier) = $(SimpleBarrier(10)))) # drop escape once merged with PR#93
push!(setup_barrier.args, init_barrier_ex)
args[i] = :(wait($(barrier)))
end
return setup_barrier
end
28 changes: 28 additions & 0 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,31 @@ end
error("The @local macro may only be used inside of a @tasks block.")
end
end

"""
@barrier
This can be used inside a `@tasks for ... end` block to mark a barrier at which all tasks
will wait for each other (synchronize) before moving on.
## Example
```julia
using OhMyThreads: @tasks
tstart = time_ns() * 1e-9
now = () -> time_ns() * 1e-9 - tstart
@tasks for i in 1:10
@set ntasks = 10
sleep(i * 0.2)
println(i, ": arriving at barrier (", now(), ")")
@barrier
println(i, ": moving on (", now(), ")")
end
```
"""
macro barrier()
error("The @barrier macro may only be used inside of a @tasks block.")
end
50 changes: 50 additions & 0 deletions src/tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,54 @@ Return a `UInt` identifier for the current running [Task](https://docs.julialang
"""
taskid() = objectid(current_task())

"""
SimpleBarrier(n::Integer)
Simple reusable barrier for `n` parallel tasks.
Given `b = SimpleBarrier(n)` and `n` parallel tasks, each task that calls
`wait(b)` will block until the other `n-1` tasks have called `wait(b)` as well.
## Example
```
using OhMyThreads.Tools: SimpleBarrier, @tasks
n = nthreads()
barrier = SimpleBarrier(n)
@tasks for i in 1:n
@set ntasks = n
println("A")
wait(barrier) # synchronize all tasks
println("B")
wait(barrier) # synchronize all tasks (reusable)
println("C")
end
```
"""
struct SimpleBarrier
n::Int64
c::Threads.Condition
cnt::Base.RefValue{Int64}

function SimpleBarrier(n::Integer)
new(n, Threads.Condition(), Base.RefValue{Int64}(0))
end
end

function Base.wait(b::SimpleBarrier)
lock(b.c)
try
b.cnt[] += 1
if b.cnt[] == b.n
b.cnt[] = 0
notify(b.c)
else
wait(b.c)
end
finally
unlock(b.c)
end
end

end # Tools

0 comments on commit cd08e66

Please sign in to comment.