Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for enumerate(chunks(...)) #117

Merged
merged 3 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
OhMyThreads.jl Changelog
=========================

Version 0.6.2
-------------
- ![Enhancement][badge-enhancement] Added API support for `enumerate(chunks(...))`. Best used in combination with `chunking=false`.

Version 0.6.1
-------------

Version 0.6.0
-------------
- ![BREAKING][badge-breaking] Drop support for Julia < 1.10.
Expand Down
11 changes: 6 additions & 5 deletions src/implementation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ end
# DynamicScheduler: ChunkSplitters.Chunk
function _tmapreduce(f,
op,
Arrs::Tuple{ChunkSplitters.Chunk{T}}, # we don't support multiple chunks for now
Arrs::Union{Tuple{ChunkSplitters.Chunk{T}}, Tuple{ChunkSplitters.Enumerate{T}}},
::Type{OutputType},
scheduler::DynamicScheduler,
mapreduce_kwargs)::OutputType where {OutputType, T}
(; nchunks, split, threadpool) = scheduler
(; threadpool) = scheduler
chunking_enabled(scheduler) && auto_disable_chunking_warning()
tasks = map(only(Arrs)) do idcs
@spawn threadpool promise_task_local(f)(idcs)
Expand Down Expand Up @@ -320,7 +320,7 @@ function tmap(f, ::Type{T}, A::AbstractArray, _Arrs::AbstractArray...; kwargs...
end

function tmap(f,
A::Union{AbstractArray, ChunkSplitters.Chunk},
A::Union{AbstractArray, ChunkSplitters.Chunk, ChunkSplitters.Enumerate},
_Arrs::AbstractArray...;
scheduler::MaybeScheduler = NotGiven(),
kwargs...)
Expand All @@ -333,7 +333,8 @@ function tmap(f,
_scheduler.split != :batch
error("Only `split == :batch` is supported because the parallel operation isn't commutative. (Scheduler: $_scheduler)")
end
if A isa ChunkSplitters.Chunk && chunking_enabled(_scheduler)
if (A isa ChunkSplitters.Chunk || A isa ChunkSplitters.Enumerate) &&
chunking_enabled(_scheduler)
auto_disable_chunking_warning()
if _scheduler isa DynamicScheduler
_scheduler = DynamicScheduler(;
Expand Down Expand Up @@ -377,7 +378,7 @@ end
# w/o chunking (DynamicScheduler{NoChunking}): ChunkSplitters.Chunk
function _tmap(scheduler::DynamicScheduler{NoChunking},
f,
A::ChunkSplitters.Chunk,
A::Union{ChunkSplitters.Chunk, ChunkSplitters.Enumerate},
_Arrs::AbstractArray...)
(; threadpool) = scheduler
tasks = map(A) do idcs
Expand Down
29 changes: 29 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,18 @@ end;
@test isnothing(tforeach(x -> sin.(x), chnks; scheduler))
end
end

# enumerate(chunks)
data = 1:100
@test tmapreduce(+, enumerate(OhMyThreads.chunks(data; n=5)); chunking=false) do (i, idcs)
[i, sum(@view(data[idcs]))]
end == [sum(1:5), sum(data)]
@test tmapreduce(+, enumerate(OhMyThreads.chunks(data; size=5)); chunking=false) do (i, idcs)
[i, sum(@view(data[idcs]))]
end == [sum(1:20), sum(data)]
@test tmap(enumerate(OhMyThreads.chunks(data; n=5)); chunking=false) do (i, idcs)
[i, idcs]
end == [[1, 1:20], [2, 21:40], [3, 41:60], [4, 61:80], [5, 81:100]]
end;

@testset "macro API" begin
Expand Down Expand Up @@ -246,6 +258,23 @@ end;
@set reducer = +
C.x
end) == 10 * var

# enumerate(chunks)
data = collect(1:100)
@test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunks(data; n=5))
@set reducer = +
@set chunking = false
[i, sum(@view(data[idcs]))]
end) == [sum(1:5), sum(data)]
@test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunks(data; size=5))
@set reducer = +
[i, sum(@view(data[idcs]))]
end) == [sum(1:20), sum(data)]
@test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunks(1:100; n=5))
@set chunking=false
@set collect=true
[i, idcs]
end) == [[1, 1:20], [2, 21:40], [3, 41:60], [4, 61:80], [5, 81:100]]
end;

@testset "WithTaskLocals" begin
Expand Down
Loading