From 732d716996d59fa728c1d6d8f3873af8754a2637 Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Wed, 21 Feb 2024 13:00:34 +0100 Subject: [PATCH] try implement threaded macro --- src/OhMyThreads.jl | 2 + src/macro.jl | 91 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 src/macro.jl diff --git a/src/OhMyThreads.jl b/src/OhMyThreads.jl index f5571eb1..580fa952 100644 --- a/src/OhMyThreads.jl +++ b/src/OhMyThreads.jl @@ -201,6 +201,8 @@ using .Schedulers: Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler, SpawnAllScheduler include("implementation.jl") +include("macro.jl") +export @threaded export treduce, tmapreduce, treducemap, tmap, tmap!, tforeach, tcollect export Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler, SpawnAllScheduler diff --git a/src/macro.jl b/src/macro.jl new file mode 100644 index 00000000..462e8891 --- /dev/null +++ b/src/macro.jl @@ -0,0 +1,91 @@ +function _kwarg_to_tuple(ex) + ex.head != :(=) && + throw(ArgumentError("Invalid keyword argument. Doesn't contain '='.")) + name, val = ex.args + !(name isa Symbol) && + throw(ArgumentError("First part of keyword argument isn't a symbol.")) + val isa QuoteNode && (val = val.value) + (name, val) +end + +macro threaded(args...) + forex = last(args) + kwexs = args[begin:(end - 1)] + scheduler = DynamicScheduler() + reducer = nothing + for ex in kwexs + name, val = _kwarg_to_tuple(ex) + if name == :scheduler + if val == :dynamic + scheduler = DynamicScheduler() + elseif val == :static + scheduler = StaticScheduler() + elseif val == :greedy + scheduler = GreedyScheduler() + else + scheduler = val + end + elseif name == :reduce + reducer = val + else + throw(ArgumentError("Unknown keyword argument: $name")) + end + end + + if forex.head != :for + throw(ErrorException("Expected for loop after `@threaded`.")) + else + it = forex.args[1] + itvar = it.args[1] + itrng = it.args[2] + forbody = forex.args[2] + end + + lbi = findfirst(forbody.args) do arg + arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@tasklocal") + end + if !isnothing(lbi) + assignment_ex = forbody.args[lbi].args[3] + if assignment_ex.head != :(=) + throw(ErrorException("Wrong usage of @tasklocal. Expected assignment, e.g. `A::Matrix{Float} = rand(2,2)`.")) + end + left_ex = assignment_ex.args[1] + if left_ex isa Symbol || left_ex.head != :(::) + throw(ErrorException("Wrong usage of @tasklocal. Expected typed assignment, e.g. `A::Matrix{Float} = rand(2,2)`.")) + end + tls_sym = left_ex.args[1] + tls_type = left_ex.args[2] + tls_def = assignment_ex.args[2] + tls_storage = gensym() + tlsinit = quote + $(tls_storage) = OhMyThreads.TaskLocalValue{$tls_type}(() -> $(tls_def)) + end + tlsblock = quote + $(tls_sym) = $(tls_storage)[] + end + deleteat!(forbody.args, lbi) + else + tlsinit = nothing + tlsblock = nothing + end + + q = if isnothing(reducer) + quote + $(tlsinit) + OhMyThreads.tforeach($(itrng); scheduler = $(scheduler)) do $(itvar) + $(tlsblock) + $(forbody) + end + end + else + quote + $(tlsinit) + OhMyThreads.tmapreduce( + $(reducer), $(itrng); scheduler = $(scheduler)) do $(itvar) + $(tlsblock) + $(forbody) + end + end + end + esc(q) +end