Skip to content

Commit

Permalink
update implementations and tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
carstenbauer committed Feb 7, 2024
1 parent c17d6e7 commit 5c519c9
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 45 deletions.
57 changes: 25 additions & 32 deletions src/implementation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,25 @@ using Base.Threads: nthreads, @threads
using BangBang: BangBang, append!!

function tmapreduce(f, op, Arrs...;
nchunks::Int=nthreads(),
split::Symbol=:batch,
schedule::Symbol=:dynamic,
scheduler::Scheduler=DynamicScheduler(),
outputtype::Type=Any,
mapreduce_kwargs...)

min_kwarg_len = haskey(mapreduce_kwargs, :init) ? 1 : 0
if length(mapreduce_kwargs) > min_kwarg_len
tmapreduce_kwargs_err(;mapreduce_kwargs...)
end
if schedule === :dynamic
_tmapreduce(f, op, Arrs, outputtype, nchunks, split, :default, mapreduce_kwargs)
elseif schedule === :interactive
_tmapreduce(f, op, Arrs, outputtype, nchunks, split, :interactive, mapreduce_kwargs)
elseif schedule === :greedy
_tmapreduce_greedy(f, op, Arrs, outputtype, nchunks, split, mapreduce_kwargs)
elseif schedule === :static
_tmapreduce_static(f, op, Arrs, outputtype, nchunks, split, mapreduce_kwargs)
else
schedule_err(schedule)
end
_tmapreduce(f, op, Arrs, outputtype, scheduler, mapreduce_kwargs)
end
@noinline schedule_err(s) = error(ArgumentError("Invalid schedule option: $s, expected :dynamic, :interactive, :greedy, or :static."))

# Error path for `tmapreduce` when unsupported `mapreduce` keyword arguments are
# passed. The caller allows at most `init` (see the `min_kwarg_len` check in
# `tmapreduce`); `init` is consumed here so only the offending keywords appear
# in the error message. `@noinline` keeps this cold path out of the hot caller.
@noinline function tmapreduce_kwargs_err(;init=nothing, kwargs...)
error("got unsupported keyword arguments: $((;kwargs...,)) ")
end

"""
    treducemap(op, f, A...; kwargs...)

Like `tmapreduce`, but with the `op` (reducer) and `f` (mapper) arguments
swapped. Simply forwards to `tmapreduce(f, op, A...; kwargs...)`.
"""
treducemap(op, f, A...; kwargs...) = tmapreduce(f, op, A...; kwargs...)

function _tmapreduce(f, op, Arrs, ::Type{OutputType}, nchunks, split, threadpool, mapreduce_kwargs)::OutputType where {OutputType}
function _tmapreduce(f, op, Arrs, ::Type{OutputType}, scheduler::DynamicScheduler, mapreduce_kwargs)::OutputType where {OutputType}
(; nchunks, split, threadpool) = scheduler
check_all_have_same_indices(Arrs)
tasks = map(chunks(first(Arrs); n=nchunks, split)) do inds
args = map(A -> view(A, inds), Arrs)
Expand All @@ -50,14 +38,14 @@ function _tmapreduce(f, op, Arrs, ::Type{OutputType}, nchunks, split, threadpool
mapreduce(fetch, op, tasks)
end

function _tmapreduce_greedy(f, op, Arrs, ::Type{OutputType}, nchunks, split, mapreduce_kwargs)::OutputType where {OutputType}
nchunks > 0 || throw("Error: nchunks must be a positive integer")
function _tmapreduce(f, op, Arrs, ::Type{OutputType}, scheduler::GreedyScheduler, mapreduce_kwargs)::OutputType where {OutputType}
ntasks_desired = scheduler.ntasks
if Base.IteratorSize(first(Arrs)) isa Base.SizeUnknown
ntasks = nchunks
ntasks = ntasks_desired
ch_len = 0
else
check_all_have_same_indices(Arrs)
ntasks = min(length(first(Arrs)), nchunks)
ntasks = min(length(first(Arrs)), ntasks_desired)
ch_len = length(first(Arrs))
end
ch = Channel{Tuple{eltype.(Arrs)...}}(ch_len; spawn=true) do ch
Expand All @@ -73,9 +61,9 @@ function _tmapreduce_greedy(f, op, Arrs, ::Type{OutputType}, nchunks, split, map
mapreduce(fetch, op, tasks; mapreduce_kwargs...)
end

function _tmapreduce_static(f, op, Arrs, ::Type{OutputType}, nchunks, split, mapreduce_kwargs) where {OutputType}
function _tmapreduce(f, op, Arrs, ::Type{OutputType}, scheduler::StaticScheduler, mapreduce_kwargs) where {OutputType}
(; nchunks, split) = scheduler
check_all_have_same_indices(Arrs)
nchunks > 0 || throw("Error: nchunks must be a positive integer")
n = min(nthreads(), nchunks) # We could implement strategies, like round-robin, in the future
tasks = map(enumerate(chunks(first(Arrs); n, split))) do (c, inds)
tid = @inbounds nthtid(c)
Expand Down Expand Up @@ -111,26 +99,31 @@ function tmap(f, ::Type{T}, A::AbstractArray, _Arrs::AbstractArray...; kwargs...
tmap!(f, similar(A, T), Arrs...; kwargs...)
end

function tmap(f, A::AbstractArray, _Arrs::AbstractArray...; nchunks::Int=nthreads(), schedule=:dynamic, kwargs...)
function tmap(f, A::AbstractArray, _Arrs::AbstractArray...; scheduler::Scheduler=DynamicScheduler(), kwargs...)
if scheduler isa GreedyScheduler
error("Greedy scheduler isn't supported with `tmap` unless you provide an `OutputElementType` argument, since the greedy schedule requires a commutative reducing operator.")
end
(; nchunks, split) = scheduler
if split != :batch
error("Only `split == :batch` is supported because the parallel operation isn't commutative. (Scheduler: $scheduler)")
end
Arrs = (A, _Arrs...)
check_all_have_same_indices(Arrs)
the_chunks = collect(chunks(A; n=nchunks))
if schedule == :greedy
error("Greedy schedules are not supported with `tmap` unless you provide an `OutputElementType` argument, since the greedy schedule requires a commutative reducing operator.")
end
# It's vital that we force split=:batch here because we're not doing a commutative operation!
v = tmapreduce(append!!, the_chunks; kwargs..., nchunks, split=:batch) do inds
chunk_idcs = collect(chunks(A; n=nchunks))
v = tmapreduce(append!!, chunk_idcs; scheduler, kwargs...) do inds
args = map(A -> @view(A[inds]), Arrs)
map(f, args...)
end
reshape(v, size(A)...)
end

@propagate_inbounds function tmap!(f, out, A::AbstractArray, _Arrs::AbstractArray...; kwargs...)
@propagate_inbounds function tmap!(f, out, A::AbstractArray, _Arrs::AbstractArray...; scheduler::Scheduler=DynamicScheduler(), kwargs...)
if hasfield(typeof(scheduler), :split) && scheduler.split != :batch
error("Only `split == :batch` is supported because the parallel operation isn't commutative. (Scheduler: $scheduler)")
end
Arrs = (A, _Arrs...)
@boundscheck check_all_have_same_indices((out, Arrs...))
# It's vital that we force split=:batch here because we're not doing a commutative operation!
tforeach(eachindex(out); kwargs..., split=:batch) do i
tforeach(eachindex(out); scheduler, kwargs...) do i
args = map(A -> @inbounds(A[i]), Arrs)
res = f(args...)
out[i] = res
Expand Down
9 changes: 6 additions & 3 deletions src/schedulers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ they can migrate between threads.
- `split::Symbol` (default `:batch`):
* Determines how the collection is divided into chunks. By default, each chunk consists of contiguous elements.
* See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options.
- `threadpool::Symbol` (default `:default`):
* Possible options are `:default` and `:interactive`.
"""
# Scheduler that spawns `nchunks` tasks on the chosen `threadpool`; tasks may
# migrate between threads (dynamic scheduling).
#
# Fields:
# - `nchunks`: number of chunks/tasks; defaults to a multiple of `nthreads()`
#   to enable load balancing.
# - `split`: how the input collection is divided into chunks (`:batch` keeps
#   contiguous elements; see ChunkSplitters.jl for options).
# - `threadpool`: `:default` or `:interactive`.
#
# NOTE: the diff rendering had duplicated the `nchunks` field line (old + new
# version of a whitespace-only change); a struct must declare each field once.
Base.@kwdef struct DynamicScheduler <: Scheduler
    nchunks::Int = 2 * nthreads() # a multiple of nthreads to enable load balancing
    split::Symbol = :batch
    threadpool::Symbol = :default
end

"""
Expand All @@ -43,7 +46,7 @@ they are guaranteed to stay on the assigned threads (**no task migration**).
* See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options.
"""
# Scheduler that statically pins one task per chunk to a fixed thread
# (no task migration).
#
# Fields:
# - `nchunks`: number of chunks/tasks; defaults to `nthreads()`.
# - `split`: how the input collection is divided into chunks (`:batch` keeps
#   contiguous elements; see ChunkSplitters.jl for options).
#
# NOTE: the diff rendering had duplicated the `nchunks` field line (old + new
# version of a whitespace-only change); a struct must declare each field once.
Base.@kwdef struct StaticScheduler <: Scheduler
    nchunks::Int = nthreads()
    split::Symbol = :batch
end

Expand All @@ -60,7 +63,7 @@ some additional overhead.
* Setting `nchunks < nthreads()` is an effective way to use only a subset of the available threads.
"""
# Scheduler where `ntasks` worker tasks greedily pull elements off a shared
# channel; requires a commutative reducing operator (see `_tmapreduce` for
# `GreedyScheduler` in src/implementation.jl).
#
# Fields:
# - `ntasks`: number of greedy worker tasks; defaults to `nthreads()`.
#   Setting `ntasks < nthreads()` effectively uses a subset of threads.
#
# NOTE: the diff rendering had duplicated the `ntasks` field line (old + new
# version of a whitespace-only change); a struct must declare each field once.
Base.@kwdef struct GreedyScheduler <: Scheduler
    ntasks::Int = nthreads()
end

end # module
25 changes: 15 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,38 @@ sets_to_test = [
@testset "Basics" begin
for (; ~, f, op, itrs, init) sets_to_test
@testset "f=$f, op=$op, itrs::$(typeof(itrs))" begin
@testset for schedule (:static, :dynamic, :interactive, :greedy)
@testset for sched (StaticScheduler, DynamicScheduler, GreedyScheduler)
@testset for split (:batch, :scatter)
for nchunks (1, 2, 6)
rand() < 0.25 && continue # we don't really want full coverage here

kwargs = (; schedule, split, nchunks)
if (split == :scatter || schedule == :greedy) || op (vcat, *)
if sched == GreedyScheduler
scheduler = sched(; ntasks=nchunks)
else
scheduler = sched(; nchunks, split)
end

kwargs = (; scheduler)
if (split == :scatter || sched == GreedyScheduler) || op (vcat, *)
# scatter and greedy only works for commutative operators!
else
mapreduce_f_op_itr = mapreduce(f, op, itrs...)
@test tmapreduce(f, op, itrs...; init, kwargs...) ~ mapreduce_f_op_itr
@test treducemap(op, f, itrs...; init, kwargs...) ~ mapreduce_f_op_itr
@test treduce(op, f.(itrs...); init, kwargs...) ~ mapreduce_f_op_itr
end


split == :scatter && continue
map_f_itr = map(f, itrs...)
@test all(tmap(f, Any, itrs...; kwargs...) .~ map_f_itr)
@test all(tcollect(Any, (f(x...) for x in collect(zip(itrs...))); kwargs...) .~ map_f_itr)
@test all(tcollect(Any, f.(itrs...); kwargs...) .~ map_f_itr)

RT = Core.Compiler.return_type(f, Tuple{eltype.(itrs)...})

@test tmap(f, RT, itrs...; kwargs...) ~ map_f_itr
@test tcollect(RT, (f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr
@test tcollect(RT, f.(itrs...); kwargs...) ~ map_f_itr
if schedule !== :greedy

if sched !== GreedyScheduler
@test tmap(f, itrs...; kwargs...) ~ map_f_itr
@test tcollect((f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr
@test tcollect(f.(itrs...); kwargs...) ~ map_f_itr
Expand Down

0 comments on commit 5c519c9

Please sign in to comment.