From f78affe8adcdfcac8acc59480c6d4eae7bb02d69 Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Mon, 29 Jan 2024 10:00:16 +0100 Subject: [PATCH] add and --- .gitignore | 1 + src/ThreadsBasics.jl | 7 +++-- src/tools.jl | 74 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 3 deletions(-) create mode 100644 .gitignore create mode 100644 src/tools.jl diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..722d5e71 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.vscode diff --git a/src/ThreadsBasics.jl b/src/ThreadsBasics.jl index 55a9a1f4..5602a907 100644 --- a/src/ThreadsBasics.jl +++ b/src/ThreadsBasics.jl @@ -18,7 +18,7 @@ A multithreaded function like `Base.mapreduce`. Perform a reduction over `A`, ap function `f` to each element, and then combining them with the two-argument function `op`. `op` **must** be an [associative](https://en.wikipedia.org/wiki/Associative_property) function, in the sense that `op(a, op(b, c)) ≈ op(op(a, b), c)`. If `op` is not (approximately) associative, you will get undefined -results. +results. For a very well known example of `mapreduce`, `sum(f, A)` is equivalent to `mapreduce(f, +, A)`. Doing @@ -130,7 +130,7 @@ Apply `f` to each element of `A` on multiple parallel tasks, and return `nothing function tforeach end """ - tmap(f, ::Type{OutputType}, A::AbstractArray; + tmap(f, ::Type{OutputType}, A::AbstractArray; nchunks::Int = 2 * nthreads(), split::Symbol = :batch, schedule::Symbol =:dynamic) @@ -147,7 +147,7 @@ A multithreaded function like `Base.map`. Create a new container `similar` to `A function tmap end """ - tmap!(f, out, A::AbstractArray; + tmap!(f, out, A::AbstractArray; nchunks::Int = 2 * nthreads(), split::Symbol = :batch, schedule::Symbol =:dynamic) @@ -165,6 +165,7 @@ function tmap! end include("implementation.jl") +include("tools.jl") end # module ThreadsBasics diff --git a/src/tools.jl b/src/tools.jl new file mode 100644 index 00000000..215e806c --- /dev/null +++ b/src/tools.jl @@ -0,0 +1,74 @@ +""" + @tspawnat tid -> task +Mimics `Threads.@spawn`, but assigns the task to thread `tid` (with `sticky = true`). + +# Example +```julia +julia> t = @tspawnat 4 Threads.threadid(); + +julia> fetch(t) +4 +``` +""" +macro tspawnat(thrdid, expr) + letargs = Base._lift_one_interp!(expr) + + thunk = esc(:(() -> ($expr))) + var = esc(Base.sync_varname) + tid = esc(thrdid) + @static if VERSION < v"1.9-" + nt = :(Threads.nthreads()) + else + nt = :(Threads.maxthreadid()) + end + quote + if $tid < 1 || $tid > $nt + throw(ArgumentError("Invalid thread id ($($tid)). Must be between in " * + "1:(total number of threads), i.e. $(1:$nt).")) + end + let $(letargs...) + local task = Task($thunk) + task.sticky = true + ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid - 1) + if $(Expr(:islocal, var)) + put!($var, task) + end + schedule(task) + task + end + end +end + +""" +threadids(threadpool = :default) -> Vector{Int} + +Returns the thread ids of the threads in the given threadpool. + +Supported values for `threadpool` are `:default`, `:interactive`, and `:all`, where the latter +provides all thread ids with default threads coming first. +""" +function threadids(threadpool = :default)::Vector{Int} + @static if VERSION < v"1.9-" + return collect(1:nthreads()) + else + if threadpool == :all + nt = nthreads(:default) + nthreads(:interactive) + tids_default = filter(i -> Threads.threadpool(i) == :default, 1:Threads.maxthreadid()) + tids_interactive = filter(i -> Threads.threadpool(i) == :interactive, 1:Threads.maxthreadid()) + tids = vcat(tids_default, tids_interactive) + else + nt = nthreads(threadpool) + tids = filter(i -> Threads.threadpool(i) == threadpool, 1:Threads.maxthreadid()) + end + + if nt != length(tids) + # IJulia manually adds a heartbeat thread that mus be ignored... + # see https://github.com/JuliaLang/IJulia.jl/issues/1072 + # Currently, we just assume that it is the last thread. + # Might not be safe, in particular not once users can dynamically add threads + # in the future. + pop!(tids) + end + return tids + end +end