From 190dd51442a1eae4c32ecbc2e6cbbc51deb0aaf8 Mon Sep 17 00:00:00 2001 From: Romeo Valentin Date: Sun, 28 Jan 2024 16:35:10 -0800 Subject: [PATCH 1/5] Implement `tcollect`, which essentially just calls `tmap` with a Generator input. --- src/ThreadsBasics.jl | 18 +++++++++++++++++- src/implementation.jl | 8 +++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/ThreadsBasics.jl b/src/ThreadsBasics.jl index 55a9a1f4..87bb358e 100644 --- a/src/ThreadsBasics.jl +++ b/src/ThreadsBasics.jl @@ -4,7 +4,7 @@ module ThreadsBasics using StableTasks: @spawn using ChunkSplitters: chunks -export chunks, treduce, tmapreduce, treducemap, tmap, tmap!, tforeach +export chunks, treduce, tmapreduce, treducemap, tmap, tmap!, tforeach, tcollect """ tmapreduce(f, op, A::AbstractArray; @@ -163,6 +163,22 @@ of `out[i] = f(A[i])` for each index `i` of `A` and `out`. """ function tmap! end +""" + tcollect(::Type{OutputType}, gen::Base.Generator{<:AbstractArray, F}; + nchunks::Int = 2 * nthreads(), + split::Symbol = :batch, + schedule::Symbol = :dynamic) + +A multithreaded function like `Base.collect`. Essentially just calls `tmap` on the generator function and inputs. + +## Keyword arguments: + +- `nchunks::Int` (default 2 * nthreads()) is passed to `ChunkSplitters.chunks` to inform it how many pieces of data should be worked on in parallel. Greater `nchunks` typically helps with [load balancing](https://en.wikipedia.org/wiki/Load_balancing_(computing)), but at the expense of creating more overhead. +- `split::Symbol` (default `:batch`) is passed to `ChunkSplitters.chunks` to inform it if the data chunks to be worked on should be contiguous (:batch) or shuffled (:scatter). `tcollect` has no reducing operator of its own; this keyword is simply forwarded to `tmap`. 
+- `schedule::Symbol` either `:dynamic` or `:static` (default `:dynamic`), determines how the parallel portions of the calculation are scheduled. `:dynamic` scheduling is generally preferred since it is more flexible and better at load balancing, but `:static` scheduling can sometimes be more performant when the time it takes to complete a step of the calculation is highly uniform, and no other parallel functions are running at the same time. +""" +function tcollect end + include("implementation.jl") diff --git a/src/implementation.jl b/src/implementation.jl index 7402060a..8db2d706 100644 --- a/src/implementation.jl +++ b/src/implementation.jl @@ -1,6 +1,6 @@ module Implementation -import ThreadsBasics: treduce, tmapreduce, treducemap, tforeach, tmap, tmap! +import ThreadsBasics: treduce, tmapreduce, treducemap, tforeach, tmap, tmap!, tcollect using ThreadsBasics: chunks, @spawn using Base: @propagate_inbounds @@ -67,5 +67,11 @@ end out end +#------------------------------------------------------------- + +function tcollect(::Type{T}, gen::Base.Generator{<:AbstractArray, F}; kwargs...) where {T, F} + tmap(gen.f, T, gen.iter; kwargs...) +end + end # module Implementation From 1d57ed374060a4827f3cf78f6ab85c593d1cf62a Mon Sep 17 00:00:00 2001 From: Romeo Valentin Date: Sun, 28 Jan 2024 16:42:47 -0800 Subject: [PATCH 2/5] Add test for `tcollect`. --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index f0e459c1..ca4c774a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,7 @@ using Test, ThreadsBasics @test all(tmap(f, Any, itr; kwargs...) .~ map_f_itr) RT = Core.Compiler.return_type(f, Tuple{eltype(itr)}) @test tmap(f, RT, itr; kwargs...) ~ map_f_itr + @test tcollect(RT, (f(x) for x in itr); kwargs...) 
~ map_f_itr end end end From b73a38bf4ff66ebb067aa1685f5978128cbab43e Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 29 Jan 2024 10:21:54 +0100 Subject: [PATCH 3/5] also add array method for `tcollect`. --- src/implementation.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/implementation.jl b/src/implementation.jl index 8db2d706..a19b36a2 100644 --- a/src/implementation.jl +++ b/src/implementation.jl @@ -72,6 +72,7 @@ end function tcollect(::Type{T}, gen::Base.Generator{<:AbstractArray, F}; kwargs...) where {T, F} tmap(gen.f, T, gen.iter; kwargs...) end +tcollect(::Type{T}, A; kwargs...) = tmap(identity, T, A; kwargs...) end # module Implementation From e6b83a0a6d4be7dedf5462b76b1e6727f119b74f Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 29 Jan 2024 10:23:42 +0100 Subject: [PATCH 4/5] Update runtests.jl --- test/runtests.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index ca4c774a..4f25594a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,9 +23,14 @@ using Test, ThreadsBasics map_f_itr = map(f, itr) @test all(tmap(f, Any, itr; kwargs...) .~ map_f_itr) + @test all(tcollect(Any, (f(x) for x in itr); kwargs...) .~ map_f_itr) + @test all(tcollect(Any, f.(itr); kwargs...) .~ map_f_itr) + RT = Core.Compiler.return_type(f, Tuple{eltype(itr)}) + @test tmap(f, RT, itr; kwargs...) ~ map_f_itr @test tcollect(RT, (f(x) for x in itr); kwargs...) ~ map_f_itr + @test tcollect(RT, f.(itr); kwargs...) 
~ map_f_itr end end end From 15aebddcd9fa6102fefd94964d67f0a8db63670c Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 29 Jan 2024 10:36:14 +0100 Subject: [PATCH 5/5] fix typo --- src/implementation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementation.jl b/src/implementation.jl index a19b36a2..47d8a94f 100644 --- a/src/implementation.jl +++ b/src/implementation.jl @@ -72,7 +72,7 @@ end function tcollect(::Type{T}, gen::Base.Generator{<:AbstractArray, F}; kwargs...) where {T, F} tmap(gen.f, T, gen.iter; kwargs...) end -tcollect(::Type{T}, A; kwargs...) = tmap(identity, T, A; kwargs...) +tcollect(::Type{T}, A; kwargs...) where {T} = tmap(identity, T, A; kwargs...) end # module Implementation