From 511ec868ee2e05888818c81869f34725c7845a50 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Wed, 13 Mar 2024 11:17:35 +0100 Subject: [PATCH] fix for `GreedyScheduler` in tests (#80) --- src/implementation.jl | 3 --- test/runtests.jl | 10 +++++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/implementation.jl b/src/implementation.jl index 2e05e4e0..e73aceb0 100644 --- a/src/implementation.jl +++ b/src/implementation.jl @@ -386,9 +386,6 @@ end _Arrs::AbstractArray...; scheduler::Scheduler = DynamicScheduler(), kwargs...) - if hasfield(typeof(scheduler), :split) && scheduler.split != :batch - error("Only `split == :batch` is supported because the parallel operation isn't commutative. (Scheduler: $scheduler)") - end Arrs = (A, _Arrs...) if scheduler isa SerialScheduler map!(f, out, Arrs...; kwargs...) diff --git a/test/runtests.jl b/test/runtests.jl index 3eee363b..20b7f315 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,6 @@ using Test, OhMyThreads using OhMyThreads: TaskLocalValue, WithTaskLocals, @fetch, promise_task_local - - sets_to_test = [(~ = isapprox, f = sin ∘ *, op = +, itrs = (rand(ComplexF64, 10, 10), rand(-10:10, 10, 10)), init = complex(0.0)) @@ -12,11 +10,13 @@ sets_to_test = [(~ = isapprox, f = sin ∘ *, op = +, itrs = ([1 => "a", 2 => "b", 3 => "c", 4 => "d", 5 => "e"],), init = "")] +ChunkedGreedy(;kwargs...) = GreedyScheduler(;kwargs...) + @testset "Basics" begin for (; ~, f, op, itrs, init) in sets_to_test @testset "f=$f, op=$op, itrs::$(typeof(itrs))" begin @testset for sched in ( - StaticScheduler, DynamicScheduler, GreedyScheduler, DynamicScheduler{OhMyThreads.Schedulers.NoChunking}, SerialScheduler) + StaticScheduler, DynamicScheduler, GreedyScheduler, DynamicScheduler{OhMyThreads.Schedulers.NoChunking}, SerialScheduler, ChunkedGreedy) @testset for split in (:batch, :scatter) for nchunks in (1, 2, 6) if sched == GreedyScheduler @@ -30,7 +30,7 @@ sets_to_test = [(~ = isapprox, f = sin ∘ *, op = +, end kwargs = (; scheduler) - if (split == :scatter || sched == GreedyScheduler) || op ∉ (vcat, *) + if (split == :scatter || sched ∈ (GreedyScheduler, ChunkedGreedy)) || op ∉ (vcat, *) # scatter and greedy only works for commutative operators! else mapreduce_f_op_itr = mapreduce(f, op, itrs...) @@ -51,7 +51,7 @@ sets_to_test = [(~ = isapprox, f = sin ∘ *, op = +, @test tcollect(RT, (f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr @test tcollect(RT, f.(itrs...); kwargs...) ~ map_f_itr - if sched !== GreedyScheduler + if sched ∉ (GreedyScheduler, ChunkedGreedy) @test tmap(f, itrs...; kwargs...) ~ map_f_itr @test tcollect((f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr @test tcollect(f.(itrs...); kwargs...) ~ map_f_itr