Skip to content

Commit

Permalink
fix #107
Browse files Browse the repository at this point in the history
  • Loading branch information
carstenbauer committed May 29, 2024
1 parent 0f9c12b commit 1ac9682
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 14 deletions.
42 changes: 28 additions & 14 deletions src/macro_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,28 @@ using OhMyThreads.Tools: OnlyOneRegion, try_enter!
using OhMyThreads.Tools: SimpleBarrier
using OhMyThreads: OhMyThreads

function _is_special_macro_expr(arg;
lookfor = ("@set", "@local", "@only_one", "@one_by_one", "@barrier"))
if !(arg isa Expr)
return false
end
lookfor_symbols = Symbol.(lookfor)
if arg.head == :macrocall
if arg.args[1] isa Symbol && arg.args[1] in lookfor_symbols
# support, e.g., @set
return true
elseif arg.args[1] isa Expr && arg.args[1].head == Symbol(".")
# support, e.g., OhMyThreads.@set
x = arg.args[1]
if x.args[1] == Symbol("OhMyThreads") && x.args[2] isa QuoteNode &&
x.args[2].value in lookfor_symbols
return true
end
end
end
return false
end

function tasks_macro(forex; __module__)
if forex.head != :for
throw(ErrorException("Expected a for loop after `@tasks`."))
Expand All @@ -24,15 +46,7 @@ function tasks_macro(forex; __module__)
# Escape everything in the loop body that is not used in conjuction with one of our
# "macros", e.g. @set or @local. Code inside of these macro blocks will be escaped by
# the respective "macro" handling functions below.
for i in findall(forbody.args) do arg
!(arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@set")) &&
!(arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@local")) &&
!(arg isa Expr && arg.head == :macrocall &&
arg.args[1] == Symbol("@only_one")) &&
!(arg isa Expr && arg.head == :macrocall &&
arg.args[1] == Symbol("@one_by_one")) &&
!(arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@barrier"))
end
for i in findall(!_is_special_macro_expr, forbody.args)
forbody.args[i] = esc(forbody.args[i])
end

Expand Down Expand Up @@ -138,7 +152,7 @@ function _maybe_handle_atlocal_block!(args)
locals_before = nothing
local_inner = nothing
tlsidx = findfirst(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@local")
_is_special_macro_expr(arg; lookfor = (Symbol("@local"),))
end
if !isnothing(tlsidx)
locals_before, local_inner = _unfold_atlocal_block(args[tlsidx].args[3])
Expand Down Expand Up @@ -198,7 +212,7 @@ end

function _maybe_handle_atset_block!(settings, args)
idcs = findall(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@set")
_is_special_macro_expr(arg; lookfor = (Symbol("@set"),))
end
isnothing(idcs) && return # no @set block found
for i in idcs
Expand Down Expand Up @@ -240,7 +254,7 @@ end

function _maybe_handle_atonlyone_blocks!(args)
idcs = findall(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@only_one")
_is_special_macro_expr(arg; lookfor = (Symbol("@only_one"),))
end
isnothing(idcs) && return # no @only_one blocks
setup_onlyone_blocks = quote end
Expand All @@ -260,7 +274,7 @@ end

function _maybe_handle_atonebyone_blocks!(args)
idcs = findall(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@one_by_one")
_is_special_macro_expr(arg; lookfor = (Symbol("@one_by_one"),))
end
isnothing(idcs) && return # no @one_by_one blocks
setup_onebyone_blocks = quote end
Expand All @@ -280,7 +294,7 @@ end

function _maybe_handle_atbarriers!(args, settings)
idcs = findall(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@barrier")
_is_special_macro_expr(arg; lookfor = (Symbol("@barrier"),))
end
isnothing(idcs) && return # no @barrier found
setup_barriers = quote end
Expand Down
68 changes: 68 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -597,4 +597,72 @@ end;
end
end

@testset "verbose special macro usage" begin
# OhMyThreads.@set
@test @tasks(for i in 1:3
OhMyThreads.@set reducer = (+)
i
end) == 6
@test @tasks(for i in 1:3
OhMyThreads.@set begin
reducer = (+)
end
i
end) == 6
# OhMyThreads.@local
ntd = 2 * Threads.nthreads()
@tasks(for i in 1:ntd
OhMyThreads.@local x::Ref{Int64} = Ref(0)
OhMyThreads.@set begin
reducer = (+)
scheduler = :static
end
x[] += 1
x[]
end) == @tasks(for i in 1:ntd
@local x::Ref{Int64} = Ref(0)
@set begin
reducer = (+)
scheduler = :static
end
x[] += 1
x[]
end)
# OhMyThreads.@only_one
x = 0
y = 0
try
@tasks for i in 1:10
OhMyThreads.@set ntasks = 10

y += 1 # not safe (race condition)
OhMyThreads.@only_one begin
x += 1 # parallel-safe because only a single task will execute this
end
end
@test x == 1 # only a single task should have incremented x
catch ErrorException
@test false
end
# OhMyThreads.@one_by_one
test_f = () -> begin
sao = SingleAccessOnly()
x = 0
y = 0
@tasks for i in 1:10
OhMyThreads.@set ntasks = 10

y += 1 # not safe (race condition)
OhMyThreads.@one_by_one begin
x += 1 # parallel-safe because inside of one_by_one region
acquire(sao) do
sleep(0.01)
end
end
end
return x
end
@test test_f() == 10
end

# Todo way more testing, and easier tests to deal with

0 comments on commit 1ac9682

Please sign in to comment.