Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add and
Browse files Browse the repository at this point in the history
carstenbauer committed Jan 29, 2024
1 parent 784afae commit f78affe
Showing 3 changed files with 79 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.vscode
7 changes: 4 additions & 3 deletions src/ThreadsBasics.jl
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@ A multithreaded function like `Base.mapreduce`. Perform a reduction over `A`, ap
function `f` to each element, and then combining them with the two-argument function `op`. `op` **must** be an
[associative](https://en.wikipedia.org/wiki/Associative_property) function, in the sense that
`op(a, op(b, c)) ≈ op(op(a, b), c)`. If `op` is not (approximately) associative, you will get undefined
results.
results.
For a very well known example of `mapreduce`, `sum(f, A)` is equivalent to `mapreduce(f, +, A)`. Doing
@@ -130,7 +130,7 @@ Apply `f` to each element of `A` on multiple parallel tasks, and return `nothing
function tforeach end

"""
tmap(f, ::Type{OutputType}, A::AbstractArray;
tmap(f, ::Type{OutputType}, A::AbstractArray;
nchunks::Int = 2 * nthreads(),
split::Symbol = :batch,
schedule::Symbol =:dynamic)
@@ -147,7 +147,7 @@ A multithreaded function like `Base.map`. Create a new container `similar` to `A
function tmap end

"""
tmap!(f, out, A::AbstractArray;
tmap!(f, out, A::AbstractArray;
nchunks::Int = 2 * nthreads(),
split::Symbol = :batch,
schedule::Symbol =:dynamic)
@@ -165,6 +165,7 @@ function tmap! end


include("implementation.jl")
include("tools.jl")


end # module ThreadsBasics
74 changes: 74 additions & 0 deletions src/tools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
@tspawnat tid -> task
Mimics `Threads.@spawn`, but assigns the task to thread `tid` (with `sticky = true`).
# Example
```julia
julia> t = @tspawnat 4 Threads.threadid();
julia> fetch(t)
4
```
"""
macro tspawnat(thrdid, expr)
letargs = Base._lift_one_interp!(expr)

thunk = esc(:(() -> ($expr)))
var = esc(Base.sync_varname)
tid = esc(thrdid)
@static if VERSION < v"1.9-"
nt = :(Threads.nthreads())
else
nt = :(Threads.maxthreadid())
end
quote
if $tid < 1 || $tid > $nt
throw(ArgumentError("Invalid thread id ($($tid)). Must be between in " *
"1:(total number of threads), i.e. $(1:$nt)."))
end
let $(letargs...)
local task = Task($thunk)
task.sticky = true
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid - 1)
if $(Expr(:islocal, var))
put!($var, task)
end
schedule(task)
task
end
end
end

"""
threadids(threadpool = :default) -> Vector{Int}
Returns the thread ids of the threads in the given threadpool.
Supported values for `threadpool` are `:default`, `:interactive`, and `:all`, where the latter
provides all thread ids with default threads coming first.
"""
function threadids(threadpool = :default)::Vector{Int}
@static if VERSION < v"1.9-"
return collect(1:nthreads())
else
if threadpool == :all
nt = nthreads(:default) + nthreads(:interactive)
tids_default = filter(i -> Threads.threadpool(i) == :default, 1:Threads.maxthreadid())
tids_interactive = filter(i -> Threads.threadpool(i) == :interactive, 1:Threads.maxthreadid())
tids = vcat(tids_default, tids_interactive)
else
nt = nthreads(threadpool)
tids = filter(i -> Threads.threadpool(i) == threadpool, 1:Threads.maxthreadid())
end

if nt != length(tids)
# IJulia manually adds a heartbeat thread that mus be ignored...
# see https://github.com/JuliaLang/IJulia.jl/issues/1072
# Currently, we just assume that it is the last thread.
# Might not be safe, in particular not once users can dynamically add threads
# in the future.
pop!(tids)
end
return tids
end
end

0 comments on commit f78affe

Please sign in to comment.