From 26854ccf322d775b48865318637cc0b9fb0c0731 Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Fri, 20 Sep 2024 17:47:38 +0200 Subject: [PATCH] prepare for ChunkSplitters 3.0 --- CHANGELOG.md | 7 ++- Project.toml | 2 +- .../src/literate/falsesharing/falsesharing.jl | 8 +-- .../src/literate/falsesharing/falsesharing.md | 8 +-- docs/src/literate/mc/mc.jl | 8 +-- docs/src/literate/mc/mc.md | 8 +-- docs/src/literate/tls/tls.jl | 4 +- docs/src/literate/tls/tls.md | 4 +- docs/src/refs/api.md | 2 +- src/OhMyThreads.jl | 7 ++- src/implementation.jl | 52 ++++++++++--------- src/schedulers.jl | 29 ++++++----- test/runtests.jl | 28 +++++----- 13 files changed, 89 insertions(+), 78 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d536c69..0311c593 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,14 @@ OhMyThreads.jl Changelog ========================= +Version 0.7.0 +------------- +- ![BREAKING][badge-breaking] We now use ChunkSplitters version 3.0. The `split` keyword argument now requires a `::Split` rather than a `::Symbol`. Replace `:batch` by `BatchSplit()` and `:scatter` by `ScatterSplit()`. Moreover, the function `OhMyThreads.chunks` has been renamed to `OhMyThreads.chunk_indices`. +- ![Feature][badge-feature] We now re-export the functions `chunk_indices` and `chunk` from ChunkSplitters.jl. + Version 0.6.2 ------------- -- ![Enhancement][badge-enhancement] Added API support for `enumerate(chunks(...))`. Best used in combination with `chunking=false`. +- ![Enhancement][badge-enhancement] Added API support for `enumerate(chunks(...))`. Best used in combination with `chunking=false` Version 0.6.1 ------------- diff --git a/Project.toml b/Project.toml index 6fb95d1b..2a3dc26e 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,7 @@ TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34" [compat] Aqua = "0.8" BangBang = "0.3.40, 0.4" -ChunkSplitters = "2.4" +ChunkSplitters = "3" StableTasks = "0.1.5" TaskLocalValues = "0.1" Test = "1" diff --git a/docs/src/literate/falsesharing/falsesharing.jl b/docs/src/literate/falsesharing/falsesharing.jl index 6c395923..11264ce8 100644 --- a/docs/src/literate/falsesharing/falsesharing.jl +++ b/docs/src/literate/falsesharing/falsesharing.jl @@ -30,11 +30,11 @@ data = rand(1_000_000 * nthreads()); # # A common, manual implementation of this idea might look like this: -using OhMyThreads: @spawn, chunks +using OhMyThreads: @spawn, chunk_indices function parallel_sum_falsesharing(data; nchunks = nthreads()) psums = zeros(eltype(data), nchunks) - @sync for (c, idcs) in enumerate(chunks(data; n = nchunks)) + @sync for (c, idcs) in enumerate(chunk_indices(data; n = nchunks)) @spawn begin for i in idcs psums[c] += data[i] @@ -102,7 +102,7 @@ nthreads() function parallel_sum_tasklocal(data; nchunks = nthreads()) psums = zeros(eltype(data), nchunks) - @sync for (c, idcs) in enumerate(chunks(data; n = nchunks)) + @sync for (c, idcs) in enumerate(chunk_indices(data; n = nchunks)) @spawn begin local s = zero(eltype(data)) for i in idcs @@ -131,7 +131,7 @@ end # using `map` and reusing the built-in (sequential) `sum` function on each parallel task: function parallel_sum_map(data; nchunks = nthreads()) - ts = map(chunks(data, n = nchunks)) do idcs + ts = map(chunk_indices(data, n = nchunks)) do idcs @spawn @views sum(data[idcs]) end return sum(fetch.(ts)) diff --git a/docs/src/literate/falsesharing/falsesharing.md b/docs/src/literate/falsesharing/falsesharing.md index 94655365..068aeae4 100644 --- a/docs/src/literate/falsesharing/falsesharing.md +++ 
b/docs/src/literate/falsesharing/falsesharing.md @@ -39,11 +39,11 @@ catastrophic numerical errors due to potential rearrangements of terms in the su A common, manual implementation of this idea might look like this: ````julia -using OhMyThreads: @spawn, chunks +using OhMyThreads: @spawn, chunk_indices function parallel_sum_falsesharing(data; nchunks = nthreads()) psums = zeros(eltype(data), nchunks) - @sync for (c, idcs) in enumerate(chunks(data; n = nchunks)) + @sync for (c, idcs) in enumerate(chunk_indices(data; n = nchunks)) @spawn begin for i in idcs psums[c] += data[i] @@ -132,7 +132,7 @@ into `psums` (once!). ````julia function parallel_sum_tasklocal(data; nchunks = nthreads()) psums = zeros(eltype(data), nchunks) - @sync for (c, idcs) in enumerate(chunks(data; n = nchunks)) + @sync for (c, idcs) in enumerate(chunk_indices(data; n = nchunks)) @spawn begin local s = zero(eltype(data)) for i in idcs @@ -168,7 +168,7 @@ using `map` and reusing the built-in (sequential) `sum` function on each paralle ````julia function parallel_sum_map(data; nchunks = nthreads()) - ts = map(chunks(data, n = nchunks)) do idcs + ts = map(chunk_indices(data, n = nchunks)) do idcs @spawn @views sum(data[idcs]) end return sum(fetch.(ts)) diff --git a/docs/src/literate/mc/mc.jl b/docs/src/literate/mc/mc.jl index 6a9abd37..30816c70 100644 --- a/docs/src/literate/mc/mc.jl +++ b/docs/src/literate/mc/mc.jl @@ -79,15 +79,15 @@ using OhMyThreads: StaticScheduler # ## Manual parallelization # -# First, using the `chunks` function, we divide the iteration interval `1:N` into +# First, using the `chunk_indices` function, we divide the iteration interval `1:N` into # `nthreads()` parts. Then, we apply a regular (sequential) `map` to spawn a Julia task # per chunk. Each task will locally and independently perform a sequential Monte Carlo # simulation. Finally, we fetch the results and compute the average estimate for $\pi$. -using OhMyThreads: @spawn, chunks +using OhMyThreads: @spawn, chunk_indices function mc_parallel_manual(N; nchunks = nthreads()) - tasks = map(chunks(1:N; n = nchunks)) do idcs + tasks = map(chunk_indices(1:N; n = nchunks)) do idcs @spawn mc(length(idcs)) end pi = sum(fetch, tasks) / nchunks @@ -104,7 +104,7 @@ mc_parallel_manual(N) # `mc(length(idcs))` is faster than the implicit task-local computation within # `tmapreduce` (which itself is a `mapreduce`). -idcs = first(chunks(1:N; n = nthreads())) +idcs = first(chunk_indices(1:N; n = nthreads())) @btime mapreduce($+, $idcs) do i rand()^2 + rand()^2 < 1.0 diff --git a/docs/src/literate/mc/mc.md b/docs/src/literate/mc/mc.md index 44506abb..27bb4d32 100644 --- a/docs/src/literate/mc/mc.md +++ b/docs/src/literate/mc/mc.md @@ -112,16 +112,16 @@ using OhMyThreads: StaticScheduler ## Manual parallelization -First, using the `chunks` function, we divide the iteration interval `1:N` into +First, using the `chunk_indices` function, we divide the iteration interval `1:N` into `nthreads()` parts. Then, we apply a regular (sequential) `map` to spawn a Julia task per chunk. Each task will locally and independently perform a sequential Monte Carlo simulation. Finally, we fetch the results and compute the average estimate for $\pi$. 
````julia -using OhMyThreads: @spawn, chunks +using OhMyThreads: @spawn, chunk_indices function mc_parallel_manual(N; nchunks = nthreads()) - tasks = map(chunks(1:N; n = nchunks)) do idcs + tasks = map(chunk_indices(1:N; n = nchunks)) do idcs @spawn mc(length(idcs)) end pi = sum(fetch, tasks) / nchunks @@ -151,7 +151,7 @@ It is faster than `mc_parallel` above because the task-local computation `tmapreduce` (which itself is a `mapreduce`). ````julia -idcs = first(chunks(1:N; n = nthreads())) +idcs = first(chunk_indices(1:N; n = nthreads())) @btime mapreduce($+, $idcs) do i rand()^2 + rand()^2 < 1.0 diff --git a/docs/src/literate/tls/tls.jl b/docs/src/literate/tls/tls.jl index 20c77adc..e84139d3 100644 --- a/docs/src/literate/tls/tls.jl +++ b/docs/src/literate/tls/tls.jl @@ -102,12 +102,12 @@ res ≈ res_naive # iterations (i.e. matrix pairs) for which this task is responsible. # Before we learn how to do this more conveniently, let's implement this idea of a # task-local temporary buffer (for each parallel task) manually. -using OhMyThreads: chunks, @spawn +using OhMyThreads: chunk_indices, @spawn using Base.Threads: nthreads function matmulsums_manual(As, Bs) N = size(first(As), 1) - tasks = map(chunks(As; n = 2 * nthreads())) do idcs + tasks = map(chunk_indices(As; n = 2 * nthreads())) do idcs @spawn begin local C = Matrix{Float64}(undef, N, N) map(idcs) do i diff --git a/docs/src/literate/tls/tls.md b/docs/src/literate/tls/tls.md index 13667605..c2a39911 100644 --- a/docs/src/literate/tls/tls.md +++ b/docs/src/literate/tls/tls.md @@ -140,12 +140,12 @@ Before we learn how to do this more conveniently, let's implement this idea of a task-local temporary buffer (for each parallel task) manually. ````julia -using OhMyThreads: chunks, @spawn +using OhMyThreads: chunk_indices, @spawn using Base.Threads: nthreads function matmulsums_manual(As, Bs) N = size(first(As), 1) - tasks = map(chunks(As; n = 2 * nthreads())) do idcs + tasks = map(chunk_indices(As; n = 2 * nthreads())) do idcs @spawn begin local C = Matrix{Float64}(undef, N, N) map(idcs) do i diff --git a/docs/src/refs/api.md b/docs/src/refs/api.md index 3f9134db..af0ac56b 100644 --- a/docs/src/refs/api.md +++ b/docs/src/refs/api.md @@ -45,7 +45,7 @@ SerialScheduler | `OhMyThreads.@spawnat` | see [StableTasks.jl](https://github.com/JuliaFolds2/StableTasks.jl) | | `OhMyThreads.@fetch` | see [StableTasks.jl](https://github.com/JuliaFolds2/StableTasks.jl) | | `OhMyThreads.@fetchfrom` | see [StableTasks.jl](https://github.com/JuliaFolds2/StableTasks.jl) | -| `OhMyThreads.chunks` | see [ChunkSplitters.jl](https://juliafolds2.github.io/ChunkSplitters.jl/dev/references/#ChunkSplitters.chunks) | +| `OhMyThreads.chunk_indices` | see [ChunkSplitters.jl](https://juliafolds2.github.io/ChunkSplitters.jl/dev/references/#ChunkSplitters.chunk_indices) | | `OhMyThreads.TaskLocalValue` | see [TaskLocalValues.jl](https://github.com/vchuravy/TaskLocalValues.jl) | diff --git a/src/OhMyThreads.jl b/src/OhMyThreads.jl index 2dc90a60..a4981033 100644 --- a/src/OhMyThreads.jl +++ b/src/OhMyThreads.jl @@ -6,7 +6,12 @@ for mac in Symbol.(["@spawn", "@spawnat", "@fetch", "@fetchfrom"]) end using ChunkSplitters: ChunkSplitters -const chunks = ChunkSplitters.chunks +const chunk_indices = ChunkSplitters.chunk_indices +const chunk = ChunkSplitters.chunk +const Split = ChunkSplitters.Split +const BatchSplit = ChunkSplitters.BatchSplit +const ScatterSplit = ChunkSplitters.ScatterSplit +export ScatterSplit, BatchSplit, Split, chunk, chunk_indices using TaskLocalValues: 
TaskLocalValues const TaskLocalValue = TaskLocalValues.TaskLocalValue diff --git a/src/implementation.jl b/src/implementation.jl index cd7227cf..5113fc0f 100644 --- a/src/implementation.jl +++ b/src/implementation.jl @@ -1,7 +1,7 @@ module Implementation import OhMyThreads: treduce, tmapreduce, treducemap, tforeach, tmap, tmap!, tcollect -using OhMyThreads: chunks, @spawn, @spawnat, WithTaskLocals, promise_task_local +using OhMyThreads: chunk_indices, @spawn, @spawnat, WithTaskLocals, promise_task_local using OhMyThreads.Tools: nthtid using OhMyThreads: Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler, @@ -14,30 +14,32 @@ using Base: @propagate_inbounds using Base.Threads: nthreads, @threads using BangBang: append!! using ChunkSplitters: ChunkSplitters +using ChunkSplitters: BatchSplit +using ChunkSplitters.Internals: ChunksIterator, Enumerate const MaybeScheduler = Union{NotGiven, Scheduler, Symbol} include("macro_impl.jl") function auto_disable_chunking_warning() - @warn("You passed in a `ChunkSplitters.Chunk` but also a scheduler that has "* + @warn("You passed in a `ChunksIterator` but also a scheduler that has "* "chunking enabled. Will turn off internal chunking to proceed.\n"* "To avoid this warning, turn off chunking (`chunking=false`).") end -function _chunks(sched, arg) +function _chunk_indices(sched, arg) C = chunking_mode(sched) @assert C != NoChunking if C == FixedCount - chunks(arg; + chunk_indices(arg; n = sched.nchunks, - split = sched.split)::ChunkSplitters.Chunk{ - typeof(arg), ChunkSplitters.FixedCount} + split = sched.split)::ChunksIterator{ + typeof(arg), ChunkSplitters.Internals.FixedCount} elseif C == FixedSize - chunks(arg; + chunk_indices(arg; size = sched.chunksize, - split = sched.split)::ChunkSplitters.Chunk{ - typeof(arg), ChunkSplitters.FixedSize} + split = sched.split)::ChunksIterator{ + typeof(arg), ChunkSplitters.Internals.FixedSize} end end @@ -85,7 +87,7 @@ function _tmapreduce(f, (; threadpool) = scheduler check_all_have_same_indices(Arrs) if chunking_enabled(scheduler) - tasks = map(_chunks(scheduler, first(Arrs))) do inds + tasks = map(_chunk_indices(scheduler, first(Arrs))) do inds args = map(A -> view(A, inds), Arrs) # Note, calling `promise_task_local` here is only safe because we're assuming that # Base.mapreduce isn't going to magically try to do multithreading on us... 
@@ -102,10 +104,10 @@ function _tmapreduce(f, end end -# DynamicScheduler: ChunkSplitters.Chunk +# DynamicScheduler: ChunksIterator function _tmapreduce(f, op, - Arrs::Union{Tuple{ChunkSplitters.Chunk{T}}, Tuple{ChunkSplitters.Enumerate{T}}}, + Arrs::Union{Tuple{ChunksIterator{T}}, Tuple{Enumerate{T}}}, ::Type{OutputType}, scheduler::DynamicScheduler, mapreduce_kwargs)::OutputType where {OutputType, T} @@ -127,7 +129,7 @@ function _tmapreduce(f, nt = nthreads() check_all_have_same_indices(Arrs) if chunking_enabled(scheduler) - tasks = map(enumerate(_chunks(scheduler, first(Arrs)))) do (c, inds) + tasks = map(enumerate(_chunk_indices(scheduler, first(Arrs)))) do (c, inds) tid = @inbounds nthtid(mod1(c, nt)) args = map(A -> view(A, inds), Arrs) # Note, calling `promise_task_local` here is only safe because we're assuming that @@ -150,10 +152,10 @@ function _tmapreduce(f, end end -# StaticScheduler: ChunkSplitters.Chunk +# StaticScheduler: ChunksIterator function _tmapreduce(f, op, - Arrs::Tuple{ChunkSplitters.Chunk{T}}, # we don't support multiple chunks for now + Arrs::Tuple{ChunksIterator{T}}, # we don't support multiple chunks for now ::Type{OutputType}, scheduler::StaticScheduler, mapreduce_kwargs)::OutputType where {OutputType, T} @@ -235,7 +237,7 @@ function _tmapreduce(f, throw(ArgumentError("SizeUnkown iterators in combination with a greedy scheduler and chunking are currently not supported.")) end check_all_have_same_indices(Arrs) - chnks = _chunks(scheduler, first(Arrs)) + chnks = _chunk_indices(scheduler, first(Arrs)) ntasks_desired = scheduler.ntasks ntasks = min(length(chnks), ntasks_desired) @@ -320,7 +322,7 @@ function tmap(f, ::Type{T}, A::AbstractArray, _Arrs::AbstractArray...; kwargs... end function tmap(f, - A::Union{AbstractArray, ChunkSplitters.Chunk, ChunkSplitters.Enumerate}, + A::Union{AbstractArray, ChunksIterator, Enumerate}, _Arrs::AbstractArray...; scheduler::MaybeScheduler = NotGiven(), kwargs...) @@ -330,10 +332,10 @@ function tmap(f, error("Greedy scheduler isn't supported with `tmap` unless you provide an `OutputElementType` argument, since the greedy schedule requires a commutative reducing operator.") end if chunking_enabled(_scheduler) && hasfield(typeof(_scheduler), :split) && - _scheduler.split != :batch - error("Only `split == :batch` is supported because the parallel operation isn't commutative. (Scheduler: $_scheduler)") + _scheduler.split != BatchSplit() + error("Only `split == BatchSplit()` is supported because the parallel operation isn't commutative. (Scheduler: $_scheduler)") end - if (A isa ChunkSplitters.Chunk || A isa ChunkSplitters.Enumerate) && + if (A isa ChunksIterator || A isa Enumerate) && chunking_enabled(_scheduler) auto_disable_chunking_warning() if _scheduler isa DynamicScheduler @@ -375,10 +377,10 @@ function _tmap(scheduler::DynamicScheduler{NoChunking}, reshape(v, size(A)...) end -# w/o chunking (DynamicScheduler{NoChunking}): ChunkSplitters.Chunk +# w/o chunking (DynamicScheduler{NoChunking}): ChunksIterator function _tmap(scheduler::DynamicScheduler{NoChunking}, f, - A::Union{ChunkSplitters.Chunk, ChunkSplitters.Enumerate}, + A::Union{ChunksIterator, Enumerate}, _Arrs::AbstractArray...) 
(; threadpool) = scheduler tasks = map(A) do idcs @@ -387,10 +389,10 @@ function _tmap(scheduler::DynamicScheduler{NoChunking}, map(fetch, tasks) end -# w/o chunking (StaticScheduler{NoChunking}): ChunkSplitters.Chunk +# w/o chunking (StaticScheduler{NoChunking}): ChunksIterator function _tmap(scheduler::StaticScheduler{NoChunking}, f, - A::ChunkSplitters.Chunk, + A::ChunksIterator, _Arrs::AbstractArray...) nt = nthreads() tasks = map(enumerate(A)) do (c, idcs) @@ -424,7 +426,7 @@ function _tmap(scheduler::Scheduler, A::AbstractArray, _Arrs::AbstractArray...) Arrs = (A, _Arrs...) - idcs = collect(_chunks(scheduler, A)) + idcs = collect(_chunk_indices(scheduler, A)) reduction_f = append!! mapping_f = maybe_rewrap(f) do f (inds) -> begin diff --git a/src/schedulers.jl b/src/schedulers.jl index b77feb8b..6795e9a2 100644 --- a/src/schedulers.jl +++ b/src/schedulers.jl @@ -1,6 +1,7 @@ module Schedulers using Base.Threads: nthreads +using ChunkSplitters: Split, BatchSplit, ScatterSplit # Used to indicate that a keyword argument has not been set by the user. # We don't use Nothing because nothing maybe sometimes be a valid user input (e.g. for init) @@ -56,10 +57,10 @@ with other multithreaded code. - `chunksize::Integer` (default not set) * Specifies the desired chunk size (instead of the number of chunks). * The options `chunksize` and `nchunks`/`ntasks` are **mutually exclusive** (only one may be a positive integer). -- `split::Symbol` (default `:batch`): +- `split::Split` (default `BatchSplit()`): * Determines how the collection is divided into chunks (if chunking=true). By default, each chunk consists of contiguous elements and order is maintained. * See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options. - * Beware that for `split=:scatter` the order of elements isn't maintained and a reducer function must not only be associative but also **commutative**! + * Beware that for `split=ScatterSplit()` the order of elements isn't maintained and a reducer function must not only be associative but also **commutative**! - `chunking::Bool` (default `true`): * Controls whether input elements are grouped into chunks (`true`) or not (`false`). * For `chunking=false`, the arguments `nchunks`/`ntasks`, `chunksize`, and `split` are ignored and input elements are regarded as "chunks" as is. Hence, there will be one parallel task spawned per input element. Note that, depending on the input, this **might spawn many(!) tasks** and can be costly! @@ -71,10 +72,10 @@ struct DynamicScheduler{C <: ChunkingMode} <: Scheduler threadpool::Symbol nchunks::Int chunksize::Int - split::Symbol + split::Split function DynamicScheduler(threadpool::Symbol, nchunks::Integer, chunksize::Integer, - split::Symbol; chunking::Bool = true) + split::Split; chunking::Bool = true) if !(threadpool in (:default, :interactive)) throw(ArgumentError("threadpool must be either :default or :interactive")) end @@ -99,7 +100,7 @@ function DynamicScheduler(; ntasks::MaybeInteger = NotGiven(), # "alias" for nchunks chunksize::MaybeInteger = NotGiven(), chunking::Bool = true, - split::Symbol = :batch) + split::Split = BatchSplit()) if !chunking nchunks = -1 chunksize = -1 @@ -151,17 +152,17 @@ Isn't well composable with other multithreaded code though. - `chunking::Bool` (default `true`): * Controls whether input elements are grouped into chunks (`true`) or not (`false`). 
* For `chunking=false`, the arguments `nchunks`/`ntasks`, `chunksize`, and `split` are ignored and input elements are regarded as "chunks" as is. Hence, there will be one parallel task spawned per input element. Note that, depending on the input, this **might spawn many(!) tasks** and can be costly! -- `split::Symbol` (default `:batch`): +- `split::Split` (default `BatchSplit()`): * Determines how the collection is divided into chunks. By default, each chunk consists of contiguous elements and order is maintained. * See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options. - * Beware that for `split=:scatter` the order of elements isn't maintained and a reducer function must not only be associative but also **commutative**! + * Beware that for `split=ScatterSplit()` the order of elements isn't maintained and a reducer function must not only be associative but also **commutative**! """ struct StaticScheduler{C <: ChunkingMode} <: Scheduler nchunks::Int chunksize::Int - split::Symbol + split::Split - function StaticScheduler(nchunks::Integer, chunksize::Integer, split::Symbol; + function StaticScheduler(nchunks::Integer, chunksize::Integer, split::Split; chunking::Bool = true) if !chunking C = NoChunking @@ -183,7 +184,7 @@ function StaticScheduler(; ntasks::MaybeInteger = NotGiven(), # "alias" for nchunks chunksize::MaybeInteger = NotGiven(), chunking::Bool = true, - split::Symbol = :batch) + split::Split = BatchSplit()) if !chunking nchunks = -1 chunksize = -1 @@ -238,7 +239,7 @@ some additional overhead. - `chunksize::Integer` (default not set) * Specifies the desired chunk size (instead of the number of chunks). * The options `chunksize` and `nchunks` are **mutually exclusive** (only one may be a positive integer). -- `split::Symbol` (default `:scatter`): +- `split::Split` (default `ScatterSplit()`): * Determines how the collection is divided into chunks (if chunking=true). * See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options. """ @@ -246,10 +247,10 @@ struct GreedyScheduler{C <: ChunkingMode} <: Scheduler ntasks::Int nchunks::Int chunksize::Int - split::Symbol + split::Split function GreedyScheduler(ntasks::Int, nchunks::Integer, chunksize::Integer, - split::Symbol; chunking::Bool = false) + split::Split; chunking::Bool = false) ntasks > 0 || throw(ArgumentError("ntasks must be a positive integer")) if !chunking C = NoChunking @@ -271,7 +272,7 @@ function GreedyScheduler(; nchunks::MaybeInteger = NotGiven(), chunksize::MaybeInteger = NotGiven(), chunking::Bool = false, - split::Symbol = :scatter) + split::Split = ScatterSplit()) if isgiven(nchunks) || isgiven(chunksize) chunking = true end diff --git a/test/runtests.jl b/test/runtests.jl index 05f0be9a..ff973837 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using Test, OhMyThreads using OhMyThreads: TaskLocalValue, WithTaskLocals, @fetch, promise_task_local +using OhMyThreads: BatchSplit, ScatterSplit using OhMyThreads.Experimental: @barrier include("Aqua.jl") @@ -22,7 +23,7 @@ ChunkedGreedy(; kwargs...) = GreedyScheduler(; kwargs...) StaticScheduler, DynamicScheduler, GreedyScheduler, DynamicScheduler{OhMyThreads.Schedulers.NoChunking}, SerialScheduler, ChunkedGreedy) - @testset for split in (:batch, :scatter) + @testset for split in (BatchSplit(), ScatterSplit()) for nchunks in (1, 2, 6) if sched == GreedyScheduler scheduler = sched(; ntasks = nchunks) @@ -35,7 +36,7 @@ ChunkedGreedy(; kwargs...) 
= GreedyScheduler(; kwargs...) end kwargs = (; scheduler) - if (split == :scatter || + if (split == ScatterSplit() || sched ∈ (GreedyScheduler, ChunkedGreedy)) || op ∉ (vcat, *) # scatter and greedy only works for commutative operators! else @@ -45,7 +46,7 @@ ChunkedGreedy(; kwargs...) = GreedyScheduler(; kwargs...) @test treduce(op, f.(itrs...); init, kwargs...) ~ mapreduce_f_op_itr end - split == :scatter && continue + split == ScatterSplit() && continue map_f_itr = map(f, itrs...) @test all(tmap(f, Any, itrs...; kwargs...) .~ map_f_itr) @test all(tcollect(Any, (f(x...) for x in collect(zip(itrs...))); kwargs...) .~ map_f_itr) @@ -71,7 +72,7 @@ end; @testset "ChunkSplitters.Chunk" begin x = rand(100) - chnks = OhMyThreads.chunks(x; n = Threads.nthreads()) + chnks = OhMyThreads.chunk_indices(x; n = Threads.nthreads()) for scheduler in ( DynamicScheduler(; chunking = false), StaticScheduler(; chunking = false)) @testset "$scheduler" begin @@ -86,13 +87,13 @@ end; # enumerate(chunks) data = 1:100 - @test tmapreduce(+, enumerate(OhMyThreads.chunks(data; n=5)); chunking=false) do (i, idcs) + @test tmapreduce(+, enumerate(OhMyThreads.chunk_indices(data; n=5)); chunking=false) do (i, idcs) [i, sum(@view(data[idcs]))] end == [sum(1:5), sum(data)] - @test tmapreduce(+, enumerate(OhMyThreads.chunks(data; size=5)); chunking=false) do (i, idcs) + @test tmapreduce(+, enumerate(OhMyThreads.chunk_indices(data; size=5)); chunking=false) do (i, idcs) [i, sum(@view(data[idcs]))] end == [sum(1:20), sum(data)] - @test tmap(enumerate(OhMyThreads.chunks(data; n=5)); chunking=false) do (i, idcs) + @test tmap(enumerate(OhMyThreads.chunk_indices(data; n=5)); chunking=false) do (i, idcs) [i, idcs] end == [[1, 1:20], [2, 21:40], [3, 41:60], [4, 61:80], [5, 81:100]] end; @@ -261,16 +262,16 @@ end; # enumerate(chunks) data = collect(1:100) - @test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunks(data; n=5)) + @test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunk_indices(data; n=5)) @set reducer = + @set chunking = false [i, sum(@view(data[idcs]))] end) == [sum(1:5), sum(data)] - @test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunks(data; size=5)) + @test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunk_indices(data; size=5)) @set reducer = + [i, sum(@view(data[idcs]))] end) == [sum(1:20), sum(data)] - @test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunks(1:100; n=5)) + @test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunk_indices(1:100; n=5)) @set chunking=false @set collect=true [i, idcs] @@ -331,9 +332,6 @@ end; @test OhMyThreads.Schedulers.chunking_mode(sched(; nchunks = 2, chunksize = 4, chunking = false)) == OhMyThreads.Schedulers.NoChunking - @test OhMyThreads.Schedulers.chunking_mode(sched(; - nchunks = -2, chunksize = -4, split = :whatever, chunking = false)) == - OhMyThreads.Schedulers.NoChunking @test OhMyThreads.Schedulers.chunking_enabled(sched(; chunksize = 2)) == true @test OhMyThreads.Schedulers.chunking_enabled(sched(; nchunks = 2)) == true @test OhMyThreads.Schedulers.chunking_enabled(sched(; @@ -360,7 +358,7 @@ end; # scheduler not given @test tmapreduce(sin, +, 1:10000; ntasks = 2) ≈ res_tmr @test tmapreduce(sin, +, 1:10000; nchunks = 2) ≈ res_tmr - @test tmapreduce(sin, +, 1:10000; split = :scatter) ≈ res_tmr + @test tmapreduce(sin, +, 1:10000; split = ScatterSplit()) ≈ res_tmr @test tmapreduce(sin, +, 1:10000; chunksize = 2) ≈ res_tmr @test tmapreduce(sin, +, 1:10000; chunking = false) ≈ res_tmr @@ -371,7 +369,7 @@ end; @test_throws ArgumentError tmapreduce( sin, +, 
1:10000; chunksize = 2, scheduler = DynamicScheduler()) @test_throws ArgumentError tmapreduce( - sin, +, 1:10000; split = :scatter, scheduler = StaticScheduler()) + sin, +, 1:10000; split = ScatterSplit(), scheduler = StaticScheduler()) @test_throws ArgumentError tmapreduce( sin, +, 1:10000; ntasks = 3, scheduler = SerialScheduler())
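
For downstream users, a minimal migration sketch based on the renames this patch makes — `OhMyThreads.chunks` → `OhMyThreads.chunk_indices`, and `split::Symbol` (`:batch`/`:scatter`) → `split::Split` (`BatchSplit()`/`ScatterSplit()`). It is illustrative only and not part of the diff; it uses only names that appear in the patch, and the manual chunked sum follows the `parallel_sum_map` example from the false-sharing docs updated above.

```julia
using Base.Threads: nthreads
using OhMyThreads: tmapreduce, @spawn
using OhMyThreads: chunk_indices, ScatterSplit  # re-exported as of this patch

data = rand(1_000_000)

# Old (OhMyThreads 0.6.x / ChunkSplitters 2.x): `split` took a Symbol.
#   tmapreduce(sin, +, data; split = :scatter)
# New (this patch): `split` takes a `Split` instance.
tmapreduce(sin, +, data; split = ScatterSplit())

# Old: `chunks(data; n = nthreads())`.
# New: `chunk_indices` with the same keyword interface, as in the updated docs.
ts = map(chunk_indices(data; n = nthreads())) do idcs
    @spawn @views sum(data[idcs])
end
sum(fetch.(ts))
```

As the patch's docstrings note, `ScatterSplit()` does not preserve element order, so it is only valid with reduction operators that are commutative as well as associative (as `+` is here); otherwise keep the default `BatchSplit()`.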