diff --git a/src/implementation.jl b/src/implementation.jl index 2e6ad5cb..d9e3791d 100644 --- a/src/implementation.jl +++ b/src/implementation.jl @@ -105,11 +105,11 @@ end # DynamicScheduler: ChunkSplitters.Chunk function _tmapreduce(f, op, - Arrs::Tuple{ChunkSplitters.Chunk{T}}, # we don't support multiple chunks for now + Arrs::Union{Tuple{ChunkSplitters.Chunk{T}}, Tuple{ChunkSplitters.Enumerate{T}}}, ::Type{OutputType}, scheduler::DynamicScheduler, mapreduce_kwargs)::OutputType where {OutputType, T} - (; nchunks, split, threadpool) = scheduler + (; threadpool) = scheduler chunking_enabled(scheduler) && auto_disable_chunking_warning() tasks = map(only(Arrs)) do idcs @spawn threadpool promise_task_local(f)(idcs) diff --git a/test/runtests.jl b/test/runtests.jl index ee28d761..e6384043 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -83,6 +83,15 @@ end; @test isnothing(tforeach(x -> sin.(x), chnks; scheduler)) end end + + # enumerate(chunks) + data = 1:100 + @test tmapreduce(+, enumerate(chunks(data; n=5)); chunking=false) do (i, idcs) + [i, sum(@view(data[idcs]))] + end == [sum(1:5), sum(data)] + @test tmapreduce(+, enumerate(chunks(data; size=5)); chunking=false) do (i, idcs) + [i, sum(@view(data[idcs]))] + end == [sum(1:20), sum(data)] end; @testset "macro API" begin @@ -246,6 +255,18 @@ end; @set reducer = + C.x end) == 10 * var + + # enumerate(chunks) + data = collect(1:100) + @test @tasks(for (i, idcs) in enumerate(chunks(data; n=5)) + @set reducer = + + @set chunking = false + [i, sum(@view(data[idcs]))] + end) == [sum(1:5), sum(data)] + @test @tasks(for (i, idcs) in enumerate(chunks(data; size=5)) + @set reducer = + + [i, sum(@view(data[idcs]))] + end) == [sum(1:20), sum(data)] end; @testset "WithTaskLocals" begin