Implement tcollect functionality. #1

Merged · 5 commits · Jan 29, 2024
src/ThreadsBasics.jl: 17 additions, 1 deletion

@@ -4,7 +4,7 @@ module ThreadsBasics
using StableTasks: @spawn
using ChunkSplitters: chunks

export chunks, treduce, tmapreduce, treducemap, tmap, tmap!, tforeach
export chunks, treduce, tmapreduce, treducemap, tmap, tmap!, tforeach, tcollect

"""
tmapreduce(f, op, A::AbstractArray;
@@ -163,6 +163,22 @@ of `out[i] = f(A[i])` for each index `i` of `A` and `out`.
"""
function tmap! end

"""
tcollect(::Type{OutputType}, gen::Base.Generator{<:AbstractArray, F};
         nchunks::Int = 2 * nthreads(),
         split::Symbol = :batch,
         schedule::Symbol = :dynamic)

A multithreaded function like `Base.collect`. Essentially just calls `tmap` on the generator's function and its underlying collection. A plain collection `A` is also accepted, in which case `tcollect` is equivalent to `tmap(identity, OutputType, A)`.

## Keyword arguments:

- `nchunks::Int` (default 2 * nthreads()) is passed to `ChunkSplitters.chunks` to inform it how many pieces of data should be worked on in parallel. Greater `nchunks` typically helps with [load balancing](https://en.wikipedia.org/wiki/Load_balancing_(computing)), but at the expense of creating more overhead.
- `split::Symbol` (default `:batch`) is passed to `ChunkSplitters.chunks` to inform it if the data chunks to be worked on should be contiguous (`:batch`) or shuffled (`:scatter`). Since `tcollect` has no reducing operator, the choice mainly affects memory access patterns and load balancing.
- `schedule::Symbol` either `:dynamic` or `:static` (default `:dynamic`), determines how the parallel portions of the calculation are scheduled. `:dynamic` scheduling is generally preferred since it is more flexible and better at load balancing, but `:static` scheduling can sometimes be more performant when the time it takes to complete a step of the calculation is highly uniform, and no other parallel functions are running at the same time.
"""
function tcollect end
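
A minimal usage sketch of the documented API, assuming the keyword arguments listed above; the specific functions and input ranges are illustrative, not part of this PR:

```julia
using ThreadsBasics

# Collect a generator in parallel; the result element type is given explicitly.
squares = tcollect(Int, (x^2 for x in 1:10_000))

# The keyword arguments are forwarded on to the chunking/scheduling machinery.
roots = tcollect(Float64, (sqrt(x) for x in 1:10_000); nchunks = 8, schedule = :static)
```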


include("implementation.jl")

src/implementation.jl: 8 additions, 1 deletion
@@ -1,6 +1,6 @@
module Implementation

import ThreadsBasics: treduce, tmapreduce, treducemap, tforeach, tmap, tmap!
import ThreadsBasics: treduce, tmapreduce, treducemap, tforeach, tmap, tmap!, tcollect

using ThreadsBasics: chunks, @spawn
using Base: @propagate_inbounds
@@ -67,5 +67,12 @@ end
out
end

#-------------------------------------------------------------

function tcollect(::Type{T}, gen::Base.Generator{<:AbstractArray, F}; kwargs...) where {T, F}
    # Unwrap the generator and map its function over the underlying array in parallel.
    tmap(gen.f, T, gen.iter; kwargs...)
end
# Fallback for plain collections: a threaded `collect` is just a threaded identity map.
tcollect(::Type{T}, A; kwargs...) where {T} = tmap(identity, T, A; kwargs...)
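
To make the dispatch above concrete: the generator method unwraps `gen.f` and `gen.iter` and forwards them to `tmap`, so the following calls (with illustrative inputs, not from this PR) should be equivalent:

```julia
using ThreadsBasics

A = 1:1_000  # any AbstractArray; this range is just an example

# Generator method: unwraps to tmap(gen.f, T, gen.iter)
r1 = tcollect(Float64, (sqrt(x) for x in A); nchunks = 4)

# Equivalent direct tmap call
r2 = tmap(sqrt, Float64, A; nchunks = 4)

# Fallback method for plain collections: tmap(identity, T, A)
r3 = tcollect(Float64, collect(Float64, A))

r1 == r2  # expected to hold
```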


end # module Implementation
test/runtests.jl: 6 additions, 0 deletions

@@ -23,8 +23,14 @@ using Test, ThreadsBasics

map_f_itr = map(f, itr)
@test all(tmap(f, Any, itr; kwargs...) .~ map_f_itr)
@test all(tcollect(Any, (f(x) for x in itr); kwargs...) .~ map_f_itr)
@test all(tcollect(Any, f.(itr); kwargs...) .~ map_f_itr)

RT = Core.Compiler.return_type(f, Tuple{eltype(itr)})

@test tmap(f, RT, itr; kwargs...) ~ map_f_itr
@test tcollect(RT, (f(x) for x in itr); kwargs...) ~ map_f_itr
@test tcollect(RT, f.(itr); kwargs...) ~ map_f_itr
end
end
end