From 66bc2926b455467a3bdbb6229431bb8779baeb23 Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Wed, 18 Sep 2024 08:42:16 +0200 Subject: [PATCH] tmap for enumerate(chunks(...)) --- src/implementation.jl | 7 ++++--- test/runtests.jl | 8 ++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/implementation.jl b/src/implementation.jl index d9e3791..cd7227c 100644 --- a/src/implementation.jl +++ b/src/implementation.jl @@ -320,7 +320,7 @@ function tmap(f, ::Type{T}, A::AbstractArray, _Arrs::AbstractArray...; kwargs... end function tmap(f, - A::Union{AbstractArray, ChunkSplitters.Chunk}, + A::Union{AbstractArray, ChunkSplitters.Chunk, ChunkSplitters.Enumerate}, _Arrs::AbstractArray...; scheduler::MaybeScheduler = NotGiven(), kwargs...) @@ -333,7 +333,8 @@ function tmap(f, _scheduler.split != :batch error("Only `split == :batch` is supported because the parallel operation isn't commutative. (Scheduler: $_scheduler)") end - if A isa ChunkSplitters.Chunk && chunking_enabled(_scheduler) + if (A isa ChunkSplitters.Chunk || A isa ChunkSplitters.Enumerate) && + chunking_enabled(_scheduler) auto_disable_chunking_warning() if _scheduler isa DynamicScheduler _scheduler = DynamicScheduler(; @@ -377,7 +378,7 @@ end # w/o chunking (DynamicScheduler{NoChunking}): ChunkSplitters.Chunk function _tmap(scheduler::DynamicScheduler{NoChunking}, f, - A::ChunkSplitters.Chunk, + A::Union{ChunkSplitters.Chunk, ChunkSplitters.Enumerate}, _Arrs::AbstractArray...) (; threadpool) = scheduler tasks = map(A) do idcs diff --git a/test/runtests.jl b/test/runtests.jl index e638404..2fcd00b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -92,6 +92,9 @@ end; @test tmapreduce(+, enumerate(chunks(data; size=5)); chunking=false) do (i, idcs) [i, sum(@view(data[idcs]))] end == [sum(1:20), sum(data)] + @test tmap(enumerate(chunks(data; n=5)); chunking=false) do (i, idcs) + [i, idcs] + end == [[1, 1:20], [2, 21:40], [3, 41:60], [4, 61:80], [5, 81:100]] end; @testset "macro API" begin @@ -267,6 +270,11 @@ end; @set reducer = + [i, sum(@view(data[idcs]))] end) == [sum(1:20), sum(data)] + @test @tasks(for (i, idcs) in enumerate(chunks(1:100; n=5)) + @set chunking=false + @set collect=true + [i, idcs] + end) == [[1, 1:20], [2, 21:40], [3, 41:60], [4, 61:80], [5, 81:100]] end; @testset "WithTaskLocals" begin