prepare for ChunkSplitters 3.0
carstenbauer committed Sep 20, 2024
1 parent 8b657fc commit 26854cc
Showing 13 changed files with 89 additions and 78 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -1,9 +1,14 @@
OhMyThreads.jl Changelog
=========================

Version 0.7.0
-------------
- ![BREAKING][badge-breaking] We now use ChunkSplitters version 3.0. The `split` keyword argument now requires a `::Split` rather than a `::Symbol`: replace `:batch` with `BatchSplit()` and `:scatter` with `ScatterSplit()`. Moreover, the function `OhMyThreads.chunks` has been renamed to `OhMyThreads.chunk_indices` (see the migration sketch below).
- ![Feature][badge-feature] We now re-export the functions `chunk_indices` and `chunk` from ChunkSplitters.jl.
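
The following migration sketch is not part of this diff; it assumes OhMyThreads 0.7 with the re-exports listed above, that `DynamicScheduler` accepts a `split` keyword, and uses placeholder data and chunk counts.

```julia
using OhMyThreads: tmapreduce, chunk_indices, DynamicScheduler, ScatterSplit

data = rand(1_000)

# 0.6.x: DynamicScheduler(; split = :scatter)
# 0.7.0: pass a `Split` instance instead of a `Symbol`
scheduler = DynamicScheduler(; split = ScatterSplit())
tmapreduce(x -> x^2, +, data; scheduler = scheduler)

# 0.6.x: enumerate(chunks(data; n = 4))
# 0.7.0: `chunks` is now called `chunk_indices`
for (c, idcs) in enumerate(chunk_indices(data; n = 4))
    @show c, first(idcs), last(idcs)  # each `idcs` is a range of indices into `data`
end
```

Code that never passed `split` or called `OhMyThreads.chunks` directly should not need changes.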

Version 0.6.2
-------------
- ![Enhancement][badge-enhancement] Added API support for `enumerate(chunks(...))`. Best used in combination with `chunking=false`.
- ![Enhancement][badge-enhancement] Added API support for `enumerate(chunks(...))`. Best used in combination with `chunking=false`
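
A minimal usage sketch of this API (written with the new `chunk_indices` name from this commit; it assumes `tforeach` accepts the enumerated chunk iterator together with the `chunking = false` keyword, and `data`/`psums` are placeholders):

```julia
using OhMyThreads: tforeach, chunk_indices
using Base.Threads: nthreads

data = rand(1_000)
psums = zeros(nthreads())

# The input is already chunked, so disable the scheduler's own chunking.
tforeach(enumerate(chunk_indices(data; n = nthreads())); chunking = false) do (c, idcs)
    @views psums[c] = sum(data[idcs])  # `c` is the chunk number, `idcs` its index range
end
sum(psums)
```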

Version 0.6.1
-------------
2 changes: 1 addition & 1 deletion Project.toml
@@ -12,7 +12,7 @@ TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"
[compat]
Aqua = "0.8"
BangBang = "0.3.40, 0.4"
ChunkSplitters = "2.4"
ChunkSplitters = "3"
StableTasks = "0.1.5"
TaskLocalValues = "0.1"
Test = "1"
8 changes: 4 additions & 4 deletions docs/src/literate/falsesharing/falsesharing.jl
@@ -30,11 +30,11 @@ data = rand(1_000_000 * nthreads());
#
# A common, manual implementation of this idea might look like this:

using OhMyThreads: @spawn, chunks
using OhMyThreads: @spawn, chunk_indices

function parallel_sum_falsesharing(data; nchunks = nthreads())
psums = zeros(eltype(data), nchunks)
@sync for (c, idcs) in enumerate(chunks(data; n = nchunks))
@sync for (c, idcs) in enumerate(chunk_indices(data; n = nchunks))
@spawn begin
for i in idcs
psums[c] += data[i]
@@ -102,7 +102,7 @@ nthreads()

function parallel_sum_tasklocal(data; nchunks = nthreads())
psums = zeros(eltype(data), nchunks)
@sync for (c, idcs) in enumerate(chunks(data; n = nchunks))
@sync for (c, idcs) in enumerate(chunk_indices(data; n = nchunks))
@spawn begin
local s = zero(eltype(data))
for i in idcs
@@ -131,7 +131,7 @@ end
# using `map` and reusing the built-in (sequential) `sum` function on each parallel task:

function parallel_sum_map(data; nchunks = nthreads())
ts = map(chunks(data, n = nchunks)) do idcs
ts = map(chunk_indices(data, n = nchunks)) do idcs
@spawn @views sum(data[idcs])
end
return sum(fetch.(ts))
8 changes: 4 additions & 4 deletions docs/src/literate/falsesharing/falsesharing.md
@@ -39,11 +39,11 @@ catastrophic numerical errors due to potential rearrangements of terms in the su
A common, manual implementation of this idea might look like this:

````julia
using OhMyThreads: @spawn, chunks
using OhMyThreads: @spawn, chunk_indices

function parallel_sum_falsesharing(data; nchunks = nthreads())
psums = zeros(eltype(data), nchunks)
@sync for (c, idcs) in enumerate(chunks(data; n = nchunks))
@sync for (c, idcs) in enumerate(chunk_indices(data; n = nchunks))
@spawn begin
for i in idcs
psums[c] += data[i]
@@ -132,7 +132,7 @@ into `psums` (once!).
````julia
function parallel_sum_tasklocal(data; nchunks = nthreads())
psums = zeros(eltype(data), nchunks)
@sync for (c, idcs) in enumerate(chunks(data; n = nchunks))
@sync for (c, idcs) in enumerate(chunk_indices(data; n = nchunks))
@spawn begin
local s = zero(eltype(data))
for i in idcs
@@ -168,7 +168,7 @@ using `map` and reusing the built-in (sequential) `sum` function on each paralle

````julia
function parallel_sum_map(data; nchunks = nthreads())
ts = map(chunks(data, n = nchunks)) do idcs
ts = map(chunk_indices(data, n = nchunks)) do idcs
@spawn @views sum(data[idcs])
end
return sum(fetch.(ts))
8 changes: 4 additions & 4 deletions docs/src/literate/mc/mc.jl
@@ -79,15 +79,15 @@ using OhMyThreads: StaticScheduler

# ## Manual parallelization
#
# First, using the `chunks` function, we divide the iteration interval `1:N` into
# First, using the `chunk_indices` function, we divide the iteration interval `1:N` into
# `nthreads()` parts. Then, we apply a regular (sequential) `map` to spawn a Julia task
# per chunk. Each task will locally and independently perform a sequential Monte Carlo
# simulation. Finally, we fetch the results and compute the average estimate for $\pi$.

using OhMyThreads: @spawn, chunks
using OhMyThreads: @spawn, chunk_indices

function mc_parallel_manual(N; nchunks = nthreads())
tasks = map(chunks(1:N; n = nchunks)) do idcs
tasks = map(chunk_indices(1:N; n = nchunks)) do idcs
@spawn mc(length(idcs))
end
pi = sum(fetch, tasks) / nchunks
@@ -104,7 +104,7 @@ mc_parallel_manual(N)
# `mc(length(idcs))` is faster than the implicit task-local computation within
# `tmapreduce` (which itself is a `mapreduce`).

idcs = first(chunks(1:N; n = nthreads()))
idcs = first(chunk_indices(1:N; n = nthreads()))

@btime mapreduce($+, $idcs) do i
rand()^2 + rand()^2 < 1.0
8 changes: 4 additions & 4 deletions docs/src/literate/mc/mc.md
@@ -112,16 +112,16 @@ using OhMyThreads: StaticScheduler

## Manual parallelization

First, using the `chunks` function, we divide the iteration interval `1:N` into
First, using the `chunk_indices` function, we divide the iteration interval `1:N` into
`nthreads()` parts. Then, we apply a regular (sequential) `map` to spawn a Julia task
per chunk. Each task will locally and independently perform a sequential Monte Carlo
simulation. Finally, we fetch the results and compute the average estimate for $\pi$.

````julia
using OhMyThreads: @spawn, chunks
using OhMyThreads: @spawn, chunk_indices

function mc_parallel_manual(N; nchunks = nthreads())
tasks = map(chunks(1:N; n = nchunks)) do idcs
tasks = map(chunk_indices(1:N; n = nchunks)) do idcs
@spawn mc(length(idcs))
end
pi = sum(fetch, tasks) / nchunks
@@ -151,7 +151,7 @@ It is faster than `mc_parallel` above because the task-local computation
`tmapreduce` (which itself is a `mapreduce`).

````julia
idcs = first(chunks(1:N; n = nthreads()))
idcs = first(chunk_indices(1:N; n = nthreads()))

@btime mapreduce($+, $idcs) do i
rand()^2 + rand()^2 < 1.0
4 changes: 2 additions & 2 deletions docs/src/literate/tls/tls.jl
@@ -102,12 +102,12 @@ res ≈ res_naive
# iterations (i.e. matrix pairs) for which this task is responsible.
# Before we learn how to do this more conveniently, let's implement this idea of a
# task-local temporary buffer (for each parallel task) manually.
using OhMyThreads: chunks, @spawn
using OhMyThreads: chunk_indices, @spawn
using Base.Threads: nthreads

function matmulsums_manual(As, Bs)
N = size(first(As), 1)
tasks = map(chunks(As; n = 2 * nthreads())) do idcs
tasks = map(chunk_indices(As; n = 2 * nthreads())) do idcs
@spawn begin
local C = Matrix{Float64}(undef, N, N)
map(idcs) do i
4 changes: 2 additions & 2 deletions docs/src/literate/tls/tls.md
@@ -140,12 +140,12 @@ Before we learn how to do this more conveniently, let's implement this idea of a
task-local temporary buffer (for each parallel task) manually.

````julia
using OhMyThreads: chunks, @spawn
using OhMyThreads: chunk_indices, @spawn
using Base.Threads: nthreads

function matmulsums_manual(As, Bs)
N = size(first(As), 1)
tasks = map(chunks(As; n = 2 * nthreads())) do idcs
tasks = map(chunk_indices(As; n = 2 * nthreads())) do idcs
@spawn begin
local C = Matrix{Float64}(undef, N, N)
map(idcs) do i
2 changes: 1 addition & 1 deletion docs/src/refs/api.md
@@ -45,7 +45,7 @@ SerialScheduler
| `OhMyThreads.@spawnat` | see [StableTasks.jl](https://github.com/JuliaFolds2/StableTasks.jl) |
| `OhMyThreads.@fetch` | see [StableTasks.jl](https://github.com/JuliaFolds2/StableTasks.jl) |
| `OhMyThreads.@fetchfrom` | see [StableTasks.jl](https://github.com/JuliaFolds2/StableTasks.jl) |
| `OhMyThreads.chunks` | see [ChunkSplitters.jl](https://juliafolds2.github.io/ChunkSplitters.jl/dev/references/#ChunkSplitters.chunks) |
| `OhMyThreads.chunk_indices` | see [ChunkSplitters.jl](https://juliafolds2.github.io/ChunkSplitters.jl/dev/references/#ChunkSplitters.chunk_indices) |
| `OhMyThreads.TaskLocalValue` | see [TaskLocalValues.jl](https://github.com/vchuravy/TaskLocalValues.jl) |


7 changes: 6 additions & 1 deletion src/OhMyThreads.jl
@@ -6,7 +6,12 @@ for mac in Symbol.(["@spawn", "@spawnat", "@fetch", "@fetchfrom"])
end

using ChunkSplitters: ChunkSplitters
const chunks = ChunkSplitters.chunks
const chunk_indices = ChunkSplitters.chunk_indices
const chunk = ChunkSplitters.chunk
const Split = ChunkSplitters.Split
const BatchSplit = ChunkSplitters.BatchSplit
const ScatterSplit = ChunkSplitters.ScatterSplit
export ScatterSplit, BatchSplit, Split, chunk, chunk_indices

using TaskLocalValues: TaskLocalValues
const TaskLocalValue = TaskLocalValues.TaskLocalValue
52 changes: 27 additions & 25 deletions src/implementation.jl
@@ -1,7 +1,7 @@
module Implementation

import OhMyThreads: treduce, tmapreduce, treducemap, tforeach, tmap, tmap!, tcollect
using OhMyThreads: chunks, @spawn, @spawnat, WithTaskLocals, promise_task_local
using OhMyThreads: chunk_indices, @spawn, @spawnat, WithTaskLocals, promise_task_local
using OhMyThreads.Tools: nthtid
using OhMyThreads: Scheduler,
DynamicScheduler, StaticScheduler, GreedyScheduler,
@@ -14,30 +14,32 @@ using Base: @propagate_inbounds
using Base.Threads: nthreads, @threads
using BangBang: append!!
using ChunkSplitters: ChunkSplitters
using ChunkSplitters: BatchSplit
using ChunkSplitters.Internals: ChunksIterator, Enumerate

const MaybeScheduler = Union{NotGiven, Scheduler, Symbol}

include("macro_impl.jl")

function auto_disable_chunking_warning()
@warn("You passed in a `ChunkSplitters.Chunk` but also a scheduler that has "*
@warn("You passed in a `ChunksIterator` but also a scheduler that has "*
"chunking enabled. Will turn off internal chunking to proceed.\n"*
"To avoid this warning, turn off chunking (`chunking=false`).")
end

function _chunks(sched, arg)
function _chunk_indices(sched, arg)
C = chunking_mode(sched)
@assert C != NoChunking
if C == FixedCount
chunks(arg;
chunk_indices(arg;
n = sched.nchunks,
split = sched.split)::ChunkSplitters.Chunk{
typeof(arg), ChunkSplitters.FixedCount}
split = sched.split)::ChunksIterator{
typeof(arg), ChunkSplitters.Internals.FixedCount}
elseif C == FixedSize
chunks(arg;
chunk_indices(arg;
size = sched.chunksize,
split = sched.split)::ChunkSplitters.Chunk{
typeof(arg), ChunkSplitters.FixedSize}
split = sched.split)::ChunksIterator{
typeof(arg), ChunkSplitters.Internals.FixedSize}
end
end

@@ -85,7 +87,7 @@ function _tmapreduce(f,
(; threadpool) = scheduler
check_all_have_same_indices(Arrs)
if chunking_enabled(scheduler)
tasks = map(_chunks(scheduler, first(Arrs))) do inds
tasks = map(_chunk_indices(scheduler, first(Arrs))) do inds
args = map(A -> view(A, inds), Arrs)
# Note, calling `promise_task_local` here is only safe because we're assuming that
# Base.mapreduce isn't going to magically try to do multithreading on us...
@@ -102,10 +104,10 @@ function _tmapreduce(f,
end
end

# DynamicScheduler: ChunkSplitters.Chunk
# DynamicScheduler: ChunksIterator
function _tmapreduce(f,
op,
Arrs::Union{Tuple{ChunkSplitters.Chunk{T}}, Tuple{ChunkSplitters.Enumerate{T}}},
Arrs::Union{Tuple{ChunksIterator{T}}, Tuple{Enumerate{T}}},
::Type{OutputType},
scheduler::DynamicScheduler,
mapreduce_kwargs)::OutputType where {OutputType, T}
@@ -127,7 +129,7 @@ function _tmapreduce(f,
nt = nthreads()
check_all_have_same_indices(Arrs)
if chunking_enabled(scheduler)
tasks = map(enumerate(_chunks(scheduler, first(Arrs)))) do (c, inds)
tasks = map(enumerate(_chunk_indices(scheduler, first(Arrs)))) do (c, inds)
tid = @inbounds nthtid(mod1(c, nt))
args = map(A -> view(A, inds), Arrs)
# Note, calling `promise_task_local` here is only safe because we're assuming that
@@ -150,10 +152,10 @@ function _tmapreduce(f,
end
end

# StaticScheduler: ChunkSplitters.Chunk
# StaticScheduler: ChunksIterator
function _tmapreduce(f,
op,
Arrs::Tuple{ChunkSplitters.Chunk{T}}, # we don't support multiple chunks for now
Arrs::Tuple{ChunksIterator{T}}, # we don't support multiple chunks for now
::Type{OutputType},
scheduler::StaticScheduler,
mapreduce_kwargs)::OutputType where {OutputType, T}
@@ -235,7 +237,7 @@ function _tmapreduce(f,
throw(ArgumentError("SizeUnkown iterators in combination with a greedy scheduler and chunking are currently not supported."))
end
check_all_have_same_indices(Arrs)
chnks = _chunks(scheduler, first(Arrs))
chnks = _chunk_indices(scheduler, first(Arrs))
ntasks_desired = scheduler.ntasks
ntasks = min(length(chnks), ntasks_desired)

@@ -320,7 +322,7 @@ function tmap(f, ::Type{T}, A::AbstractArray, _Arrs::AbstractArray...; kwargs...
end

function tmap(f,
A::Union{AbstractArray, ChunkSplitters.Chunk, ChunkSplitters.Enumerate},
A::Union{AbstractArray, ChunksIterator, Enumerate},
_Arrs::AbstractArray...;
scheduler::MaybeScheduler = NotGiven(),
kwargs...)
@@ -330,10 +332,10 @@ function tmap(f,
error("Greedy scheduler isn't supported with `tmap` unless you provide an `OutputElementType` argument, since the greedy schedule requires a commutative reducing operator.")
end
if chunking_enabled(_scheduler) && hasfield(typeof(_scheduler), :split) &&
_scheduler.split != :batch
error("Only `split == :batch` is supported because the parallel operation isn't commutative. (Scheduler: $_scheduler)")
_scheduler.split != BatchSplit()
error("Only `split == BatchSplit()` is supported because the parallel operation isn't commutative. (Scheduler: $_scheduler)")
end
if (A isa ChunkSplitters.Chunk || A isa ChunkSplitters.Enumerate) &&
if (A isa ChunksIterator || A isa Enumerate) &&
chunking_enabled(_scheduler)
auto_disable_chunking_warning()
if _scheduler isa DynamicScheduler
@@ -375,10 +377,10 @@ function _tmap(scheduler::DynamicScheduler{NoChunking},
reshape(v, size(A)...)
end

# w/o chunking (DynamicScheduler{NoChunking}): ChunkSplitters.Chunk
# w/o chunking (DynamicScheduler{NoChunking}): ChunksIterator
function _tmap(scheduler::DynamicScheduler{NoChunking},
f,
A::Union{ChunkSplitters.Chunk, ChunkSplitters.Enumerate},
A::Union{ChunksIterator, Enumerate},
_Arrs::AbstractArray...)
(; threadpool) = scheduler
tasks = map(A) do idcs
@@ -387,10 +389,10 @@ function _tmap(scheduler::DynamicScheduler{NoChunking},
map(fetch, tasks)
end

# w/o chunking (StaticScheduler{NoChunking}): ChunkSplitters.Chunk
# w/o chunking (StaticScheduler{NoChunking}): ChunksIterator
function _tmap(scheduler::StaticScheduler{NoChunking},
f,
A::ChunkSplitters.Chunk,
A::ChunksIterator,
_Arrs::AbstractArray...)
nt = nthreads()
tasks = map(enumerate(A)) do (c, idcs)
@@ -424,7 +426,7 @@ function _tmap(scheduler::Scheduler,
A::AbstractArray,
_Arrs::AbstractArray...)
Arrs = (A, _Arrs...)
idcs = collect(_chunks(scheduler, A))
idcs = collect(_chunk_indices(scheduler, A))
reduction_f = append!!
mapping_f = maybe_rewrap(f) do f
(inds) -> begin
