diff --git a/src/implementation.jl b/src/implementation.jl
index b2dae192..8768b482 100644
--- a/src/implementation.jl
+++ b/src/implementation.jl
@@ -11,9 +11,7 @@ using Base.Threads: nthreads, @threads
 using BangBang: BangBang, append!!
 
 function tmapreduce(f, op, Arrs...;
-                    nchunks::Int=nthreads(),
-                    split::Symbol=:batch,
-                    schedule::Symbol=:dynamic,
+                    scheduler::Scheduler=DynamicScheduler(),
                     outputtype::Type=Any,
                     mapreduce_kwargs...)
@@ -21,19 +19,8 @@ function tmapreduce(f, op, Arrs...;
     if length(mapreduce_kwargs) > min_kwarg_len
         tmapreduce_kwargs_err(;mapreduce_kwargs...)
     end
-    if schedule === :dynamic
-        _tmapreduce(f, op, Arrs, outputtype, nchunks, split, :default, mapreduce_kwargs)
-    elseif schedule === :interactive
-        _tmapreduce(f, op, Arrs, outputtype, nchunks, split, :interactive, mapreduce_kwargs)
-    elseif schedule === :greedy
-        _tmapreduce_greedy(f, op, Arrs, outputtype, nchunks, split, mapreduce_kwargs)
-    elseif schedule === :static
-        _tmapreduce_static(f, op, Arrs, outputtype, nchunks, split, mapreduce_kwargs)
-    else
-        schedule_err(schedule)
-    end
+    _tmapreduce(f, op, Arrs, outputtype, scheduler, mapreduce_kwargs)
 end
 
-@noinline schedule_err(s) = error(ArgumentError("Invalid schedule option: $s, expected :dynamic, :interactive, :greedy, or :static."))
 @noinline function tmapreduce_kwargs_err(;init=nothing, kwargs...)
     error("got unsupported keyword arguments: $((;kwargs...,)) ")
@@ -41,7 +28,8 @@ end
 
 treducemap(op, f, A...; kwargs...) = tmapreduce(f, op, A...; kwargs...)
 
-function _tmapreduce(f, op, Arrs, ::Type{OutputType}, nchunks, split, threadpool, mapreduce_kwargs)::OutputType where {OutputType}
+function _tmapreduce(f, op, Arrs, ::Type{OutputType}, scheduler::DynamicScheduler, mapreduce_kwargs)::OutputType where {OutputType}
+    (; nchunks, split, threadpool) = scheduler
     check_all_have_same_indices(Arrs)
     tasks = map(chunks(first(Arrs); n=nchunks, split)) do inds
         args = map(A -> view(A, inds), Arrs)
@@ -50,14 +38,14 @@
     mapreduce(fetch, op, tasks)
 end
 
-function _tmapreduce_greedy(f, op, Arrs, ::Type{OutputType}, nchunks, split, mapreduce_kwargs)::OutputType where {OutputType}
-    nchunks > 0 || throw("Error: nchunks must be a positive integer")
+function _tmapreduce(f, op, Arrs, ::Type{OutputType}, scheduler::GreedyScheduler, mapreduce_kwargs)::OutputType where {OutputType}
+    ntasks_desired = scheduler.ntasks
     if Base.IteratorSize(first(Arrs)) isa Base.SizeUnknown
-        ntasks = nchunks
+        ntasks = ntasks_desired
         ch_len = 0
     else
         check_all_have_same_indices(Arrs)
-        ntasks = min(length(first(Arrs)), nchunks)
+        ntasks = min(length(first(Arrs)), ntasks_desired)
         ch_len = length(first(Arrs))
     end
     ch = Channel{Tuple{eltype.(Arrs)...}}(ch_len; spawn=true) do ch
@@ -73,9 +61,9 @@ function _tmapreduce_greedy(f, op, Arrs, ::Type{OutputType}, nchunks, split, map
     mapreduce(fetch, op, tasks; mapreduce_kwargs...)
 end
 
-function _tmapreduce_static(f, op, Arrs, ::Type{OutputType}, nchunks, split, mapreduce_kwargs) where {OutputType}
+function _tmapreduce(f, op, Arrs, ::Type{OutputType}, scheduler::StaticScheduler, mapreduce_kwargs) where {OutputType}
+    (; nchunks, split) = scheduler
     check_all_have_same_indices(Arrs)
-    nchunks > 0 || throw("Error: nchunks must be a positive integer")
     n = min(nthreads(), nchunks) # We could implement strategies, like round-robin, in the future
     tasks = map(enumerate(chunks(first(Arrs); n, split))) do (c, inds)
         tid = @inbounds nthtid(c)
@@ -111,26 +99,31 @@ function tmap(f, ::Type{T}, A::AbstractArray, _Arrs::AbstractArray...; kwargs...
     tmap!(f, similar(A, T), Arrs...; kwargs...)
 end
 
-function tmap(f, A::AbstractArray, _Arrs::AbstractArray...; nchunks::Int=nthreads(), schedule=:dynamic, kwargs...)
+function tmap(f, A::AbstractArray, _Arrs::AbstractArray...; scheduler::Scheduler=DynamicScheduler(), kwargs...)
+    if scheduler isa GreedyScheduler
+        error("Greedy scheduler isn't supported with `tmap` unless you provide an `OutputElementType` argument, since the greedy schedule requires a commutative reducing operator.")
+    end
+    (; nchunks, split) = scheduler
+    if split != :batch
+        error("Only `split == :batch` is supported because the parallel operation isn't commutative. (Scheduler: $scheduler)")
+    end
     Arrs = (A, _Arrs...)
     check_all_have_same_indices(Arrs)
-    the_chunks = collect(chunks(A; n=nchunks))
-    if schedule == :greedy
-        error("Greedy schedules are not supported with `tmap` unless you provide an `OutputElementType` argument, since the greedy schedule requires a commutative reducing operator.")
-    end
-    # It's vital that we force split=:batch here because we're not doing a commutative operation!
-    v = tmapreduce(append!!, the_chunks; kwargs..., nchunks, split=:batch) do inds
+    chunk_idcs = collect(chunks(A; n=nchunks))
+    v = tmapreduce(append!!, chunk_idcs; scheduler, kwargs...) do inds
         args = map(A -> @view(A[inds]), Arrs)
         map(f, args...)
     end
     reshape(v, size(A)...)
 end
 
-@propagate_inbounds function tmap!(f, out, A::AbstractArray, _Arrs::AbstractArray...; kwargs...)
+@propagate_inbounds function tmap!(f, out, A::AbstractArray, _Arrs::AbstractArray...; scheduler::Scheduler=DynamicScheduler(), kwargs...)
+    if hasfield(typeof(scheduler), :split) && scheduler.split != :batch
+        error("Only `split == :batch` is supported because the parallel operation isn't commutative. (Scheduler: $scheduler)")
+    end
     Arrs = (A, _Arrs...)
     @boundscheck check_all_have_same_indices((out, Arrs...))
-    # It's vital that we force split=:batch here because we're not doing a commutative operation!
-    tforeach(eachindex(out); kwargs..., split=:batch) do i
+    tforeach(eachindex(out); scheduler, kwargs...) do i
         args = map(A -> @inbounds(A[i]), Arrs)
         res = f(args...)
         out[i] = res
diff --git a/src/schedulers.jl b/src/schedulers.jl
index 1738551f..e10ae242 100644
--- a/src/schedulers.jl
+++ b/src/schedulers.jl
@@ -20,10 +20,13 @@ they can migrate between threads.
 - `split::Symbol` (default `:batch`):
     * Determines how the collection is divided into chunks. By default, each chunk consists of contiguous elements.
     * See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options.
+- `threadpool::Symbol` (default `:default`):
+    * Possible options are `:default` and `:interactive`.
 """
 Base.@kwdef struct DynamicScheduler <: Scheduler
-    nchunks::Int = 2 * nthreads() # a multiple of nthreads to enable load balancing
+    nchunks::Int = 2 * nthreads() # a multiple of nthreads to enable load balancing
     split::Symbol = :batch
+    threadpool::Symbol = :default
 end
 
 """
@@ -43,7 +46,7 @@ they are guaranteed to stay on the assigned threads (**no task migration**).
     * See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options.
 """
 Base.@kwdef struct StaticScheduler <: Scheduler
-    nchunks::Int = nthreads()
+    nchunks::Int = nthreads()
     split::Symbol = :batch
 end
@@ -60,7 +63,7 @@ some additional overhead.
 * Setting `nchunks < nthreads()` is an effective way to use only a subset of the available threads.
 """
 Base.@kwdef struct GreedyScheduler <: Scheduler
-    ntasks::Int = nthreads()
+    ntasks::Int = nthreads()
 end
 
 end # module
diff --git a/test/runtests.jl b/test/runtests.jl
index 244e11c1..f92c1fcd 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -11,13 +11,17 @@ sets_to_test = [
 @testset "Basics" begin
     for (; ~, f, op, itrs, init) ∈ sets_to_test
         @testset "f=$f, op=$op, itrs::$(typeof(itrs))" begin
-            @testset for schedule ∈ (:static, :dynamic, :interactive, :greedy)
+            @testset for sched ∈ (StaticScheduler, DynamicScheduler, GreedyScheduler)
                 @testset for split ∈ (:batch, :scatter)
                     for nchunks ∈ (1, 2, 6)
-                        rand() < 0.25 && continue # we don't really want full coverage here
-
-                        kwargs = (; schedule, split, nchunks)
-                        if (split == :scatter || schedule == :greedy) || op ∉ (vcat, *)
+                        if sched == GreedyScheduler
+                            scheduler = sched(; ntasks=nchunks)
+                        else
+                            scheduler = sched(; nchunks, split)
+                        end
+
+                        kwargs = (; scheduler)
+                        if (split == :scatter || sched == GreedyScheduler) || op ∉ (vcat, *)
                             # scatter and greedy only works for commutative operators!
                         else
                             mapreduce_f_op_itr = mapreduce(f, op, itrs...)
@@ -25,19 +29,20 @@ sets_to_test = [
                             @test treducemap(op, f, itrs...; init, kwargs...) ~ mapreduce_f_op_itr
                             @test treduce(op, f.(itrs...); init, kwargs...) ~ mapreduce_f_op_itr
                         end
-
+
+                        split == :scatter && continue
                        map_f_itr = map(f, itrs...)
                        @test all(tmap(f, Any, itrs...; kwargs...) .~ map_f_itr)
                        @test all(tcollect(Any, (f(x...) for x in collect(zip(itrs...))); kwargs...) .~ map_f_itr)
                        @test all(tcollect(Any, f.(itrs...); kwargs...) .~ map_f_itr)
-
+
                        RT = Core.Compiler.return_type(f, Tuple{eltype.(itrs)...})
-
+
                        @test tmap(f, RT, itrs...; kwargs...) ~ map_f_itr
                        @test tcollect(RT, (f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr
                        @test tcollect(RT, f.(itrs...); kwargs...) ~ map_f_itr
-
-                        if schedule !== :greedy
+
+                        if sched !== GreedyScheduler
                            @test tmap(f, itrs...; kwargs...) ~ map_f_itr
                            @test tcollect((f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr
                            @test tcollect(f.(itrs...); kwargs...) ~ map_f_itr