From 85e968d2c0d4314b4f195af88e26c0ae2e74772f Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Wed, 25 Sep 2024 08:29:40 +0200 Subject: [PATCH] support symbol and split --- src/OhMyThreads.jl | 2 +- src/schedulers.jl | 43 ++++++++++++++++++++++++++++++++----------- test/runtests.jl | 6 +++--- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/OhMyThreads.jl b/src/OhMyThreads.jl index a7015aa3..f7cfa58d 100644 --- a/src/OhMyThreads.jl +++ b/src/OhMyThreads.jl @@ -12,7 +12,7 @@ const Split = ChunkSplitters.Split const Consecutive = ChunkSplitters.Consecutive const RoundRobin = ChunkSplitters.RoundRobin export chunks, index_chunks -# export RoundRobin, Consecutive, Split # TODO: should we export this? +export RoundRobin, Consecutive, Split using TaskLocalValues: TaskLocalValues const TaskLocalValue = TaskLocalValues.TaskLocalValue diff --git a/src/schedulers.jl b/src/schedulers.jl index f04526e7..f4a622b6 100644 --- a/src/schedulers.jl +++ b/src/schedulers.jl @@ -57,10 +57,10 @@ with other multithreaded code. - `chunksize::Integer` (default not set) * Specifies the desired chunk size (instead of the number of chunks). * The options `chunksize` and `nchunks`/`ntasks` are **mutually exclusive** (only one may be a positive integer). -- `split::OhMyThreads.Split` (default `OhMyThreads.Consecutive()`): +- `split::Split` (default `Consecutive()`): * Determines how the collection is divided into chunks (if chunking=true). By default, each chunk consists of contiguous elements and order is maintained. * See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options. - * Beware that for `split=OhMyThreads.RoundRobin()` the order of elements isn't maintained and a reducer function must not only be associative but also **commutative**! + * Beware that for `split=RoundRobin()` the order of elements isn't maintained and a reducer function must not only be associative but also **commutative**! - `chunking::Bool` (default `true`): * Controls whether input elements are grouped into chunks (`true`) or not (`false`). * For `chunking=false`, the arguments `nchunks`/`ntasks`, `chunksize`, and `split` are ignored and input elements are regarded as "chunks" as is. Hence, there will be one parallel task spawned per input element. Note that, depending on the input, this **might spawn many(!) tasks** and can be costly! @@ -75,7 +75,7 @@ struct DynamicScheduler{C <: ChunkingMode} <: Scheduler split::Split function DynamicScheduler(threadpool::Symbol, nchunks::Integer, chunksize::Integer, - split::Split; chunking::Bool = true) + split::Union{Split, Symbol}; chunking::Bool = true) if !(threadpool in (:default, :interactive)) throw(ArgumentError("threadpool must be either :default or :interactive")) end @@ -90,6 +90,13 @@ struct DynamicScheduler{C <: ChunkingMode} <: Scheduler end C = chunksize > 0 ? FixedSize : FixedCount end + if split isa Symbol + if split in (:consecutive, :batch) + split = Consecutive() + elseif split in (:roundrobin, :scatter) + split = RoundRobin() + end + end new{C}(threadpool, nchunks, chunksize, split) end end @@ -100,7 +107,7 @@ function DynamicScheduler(; ntasks::MaybeInteger = NotGiven(), # "alias" for nchunks chunksize::MaybeInteger = NotGiven(), chunking::Bool = true, - split::Split = Consecutive()) + split::Union{Split, Symbol} = Consecutive()) if !chunking nchunks = -1 chunksize = -1 @@ -152,17 +159,17 @@ Isn't well composable with other multithreaded code though. - `chunking::Bool` (default `true`): * Controls whether input elements are grouped into chunks (`true`) or not (`false`). * For `chunking=false`, the arguments `nchunks`/`ntasks`, `chunksize`, and `split` are ignored and input elements are regarded as "chunks" as is. Hence, there will be one parallel task spawned per input element. Note that, depending on the input, this **might spawn many(!) tasks** and can be costly! -- `split::OhMyThreads.Split` (default `OhMyThreads.Consecutive()`): +- `split::Split` (default `Consecutive()`): * Determines how the collection is divided into chunks. By default, each chunk consists of contiguous elements and order is maintained. * See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options. - * Beware that for `split=OhMyThreads.RoundRobin()` the order of elements isn't maintained and a reducer function must not only be associative but also **commutative**! + * Beware that for `split=RoundRobin()` the order of elements isn't maintained and a reducer function must not only be associative but also **commutative**! """ struct StaticScheduler{C <: ChunkingMode} <: Scheduler nchunks::Int chunksize::Int split::Split - function StaticScheduler(nchunks::Integer, chunksize::Integer, split::Split; + function StaticScheduler(nchunks::Integer, chunksize::Integer, split::Union{Split, Symbol}; chunking::Bool = true) if !chunking C = NoChunking @@ -175,6 +182,13 @@ struct StaticScheduler{C <: ChunkingMode} <: Scheduler end C = chunksize > 0 ? FixedSize : FixedCount end + if split isa Symbol + if split in (:consecutive, :batch) + split = Consecutive() + elseif split in (:roundrobin, :scatter) + split = RoundRobin() + end + end new{C}(nchunks, chunksize, split) end end @@ -184,7 +198,7 @@ function StaticScheduler(; ntasks::MaybeInteger = NotGiven(), # "alias" for nchunks chunksize::MaybeInteger = NotGiven(), chunking::Bool = true, - split::Split = Consecutive()) + split::Union{Split, Symbol} = Consecutive()) if !chunking nchunks = -1 chunksize = -1 @@ -239,7 +253,7 @@ some additional overhead. - `chunksize::Integer` (default not set) * Specifies the desired chunk size (instead of the number of chunks). * The options `chunksize` and `nchunks` are **mutually exclusive** (only one may be a positive integer). -- `split::OhMyThreads.Split` (default `OhMyThreads.RoundRobin()`): +- `split::Split` (default `RoundRobin()`): * Determines how the collection is divided into chunks (if chunking=true). * See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options. """ @@ -250,7 +264,7 @@ struct GreedyScheduler{C <: ChunkingMode} <: Scheduler split::Split function GreedyScheduler(ntasks::Int, nchunks::Integer, chunksize::Integer, - split::Split; chunking::Bool = false) + split::Union{Split, Symbol}; chunking::Bool = false) ntasks > 0 || throw(ArgumentError("ntasks must be a positive integer")) if !chunking C = NoChunking @@ -263,6 +277,13 @@ struct GreedyScheduler{C <: ChunkingMode} <: Scheduler end C = chunksize > 0 ? FixedSize : FixedCount end + if split isa Symbol + if split in (:consecutive, :batch) + split = Consecutive() + elseif split in (:roundrobin, :scatter) + split = RoundRobin() + end + end new{C}(ntasks, nchunks, chunksize, split) end end @@ -272,7 +293,7 @@ function GreedyScheduler(; nchunks::MaybeInteger = NotGiven(), chunksize::MaybeInteger = NotGiven(), chunking::Bool = false, - split::Split = RoundRobin()) + split::Union{Split, Symbol} = RoundRobin()) if isgiven(nchunks) || isgiven(chunksize) chunking = true end diff --git a/test/runtests.jl b/test/runtests.jl index 3540b6e2..3492a975 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,7 +23,7 @@ ChunkedGreedy(; kwargs...) = GreedyScheduler(; kwargs...) StaticScheduler, DynamicScheduler, GreedyScheduler, DynamicScheduler{OhMyThreads.Schedulers.NoChunking}, SerialScheduler, ChunkedGreedy) - @testset for split in (Consecutive(), RoundRobin()) + @testset for split in (Consecutive(), RoundRobin(), :consecutive, :roundrobin) for nchunks in (1, 2, 6) if sched == GreedyScheduler scheduler = sched(; ntasks = nchunks) @@ -36,7 +36,7 @@ ChunkedGreedy(; kwargs...) = GreedyScheduler(; kwargs...) end kwargs = (; scheduler) - if (split == RoundRobin() || + if (split in (RoundRobin(), :roundrobin) || sched ∈ (GreedyScheduler, ChunkedGreedy)) || op ∉ (vcat, *) # scatter and greedy only works for commutative operators! else @@ -46,7 +46,7 @@ ChunkedGreedy(; kwargs...) = GreedyScheduler(; kwargs...) @test treduce(op, f.(itrs...); init, kwargs...) ~ mapreduce_f_op_itr end - split == RoundRobin() && continue + split in (RoundRobin(), :roundrobin) && continue map_f_itr = map(f, itrs...) @test all(tmap(f, Any, itrs...; kwargs...) .~ map_f_itr) @test all(tcollect(Any, (f(x...) for x in collect(zip(itrs...))); kwargs...) .~ map_f_itr)