Skip to content

Commit

Permalink
blub
Browse files Browse the repository at this point in the history
  • Loading branch information
carstenbauer committed Sep 25, 2024
1 parent 85e968d commit c3e695b
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions src/implementation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ using Base: @propagate_inbounds
using Base.Threads: nthreads, @threads
using BangBang: append!!
using ChunkSplitters: ChunkSplitters, index_chunks, Consecutive
using ChunkSplitters.Internals: AbstractChunksIterator, IndexChunks
using ChunkSplitters.Internals: AbstractChunks, IndexChunks

const MaybeScheduler = Union{NotGiven, Scheduler, Symbol}

include("macro_impl.jl")

function auto_disable_chunking_warning()
@warn("You passed in a `<:AbstractChunksIterator` but also a scheduler that has "*
"chunking enabled. Will turn off internal chunking to proceed.\n"*
"To avoid this warning, turn off chunking (`chunking=false`).")
end
# function auto_disable_chunking_warning()
# @warn("You passed in a `<:AbstractChunks` but also a scheduler that has "*
# "chunking enabled. Will turn off internal chunking to proceed.\n"*
# "To avoid this warning, turn off chunking (`chunking=false`).")
# end

function _index_chunks(sched, arg)
C = chunking_mode(sched)
Expand Down Expand Up @@ -103,15 +103,14 @@ function _tmapreduce(f,
end
end

# DynamicScheduler: AbstractChunksIterator
# DynamicScheduler: AbstractChunks
function _tmapreduce(f,
op,
Arrs::Union{Tuple{AbstractChunksIterator{T}}, Tuple{ChunkSplitters.Internals.Enumerate{T}}},
Arrs::Union{Tuple{AbstractChunks{T}}, Tuple{ChunkSplitters.Internals.Enumerate{T}}},
::Type{OutputType},
scheduler::DynamicScheduler,
mapreduce_kwargs)::OutputType where {OutputType, T}
(; threadpool) = scheduler
chunking_enabled(scheduler) && auto_disable_chunking_warning()
tasks = map(only(Arrs)) do idcs
@spawn threadpool promise_task_local(f)(idcs)
end
Expand Down Expand Up @@ -151,14 +150,13 @@ function _tmapreduce(f,
end
end

# StaticScheduler: AbstractChunksIterator
# StaticScheduler: AbstractChunks
function _tmapreduce(f,
op,
Arrs::Tuple{AbstractChunksIterator{T}}, # we don't support multiple chunks for now
Arrs::Tuple{AbstractChunks{T}}, # we don't support multiple chunks for now
::Type{OutputType},
scheduler::StaticScheduler,
mapreduce_kwargs)::OutputType where {OutputType, T}
chunking_enabled(scheduler) && auto_disable_chunking_warning()
check_all_have_same_indices(Arrs)
chnks = only(Arrs)
nt = nthreads()
Expand Down Expand Up @@ -321,7 +319,7 @@ function tmap(f, ::Type{T}, A::AbstractArray, _Arrs::AbstractArray...; kwargs...
end

function tmap(f,
A::Union{AbstractArray, AbstractChunksIterator, ChunkSplitters.Internals.Enumerate},
A::Union{AbstractArray, AbstractChunks, ChunkSplitters.Internals.Enumerate},
_Arrs::AbstractArray...;
scheduler::MaybeScheduler = NotGiven(),
kwargs...)
Expand All @@ -334,9 +332,8 @@ function tmap(f,
_scheduler.split != Consecutive()
error("Only `split == Consecutive()` is supported because the parallel operation isn't commutative. (Scheduler: $_scheduler)")
end
if (A isa AbstractChunksIterator || A isa ChunkSplitters.Internals.Enumerate) &&
if (A isa AbstractChunks || A isa ChunkSplitters.Internals.Enumerate) &&
chunking_enabled(_scheduler)
auto_disable_chunking_warning()
if _scheduler isa DynamicScheduler
_scheduler = DynamicScheduler(;
threadpool = _scheduler.threadpool,
Expand Down Expand Up @@ -376,10 +373,10 @@ function _tmap(scheduler::DynamicScheduler{NoChunking},
reshape(v, size(A)...)
end

# w/o chunking (DynamicScheduler{NoChunking}): AbstractChunksIterator
# w/o chunking (DynamicScheduler{NoChunking}): AbstractChunks
function _tmap(scheduler::DynamicScheduler{NoChunking},
f,
A::Union{AbstractChunksIterator, ChunkSplitters.Internals.Enumerate},
A::Union{AbstractChunks, ChunkSplitters.Internals.Enumerate},
_Arrs::AbstractArray...)
(; threadpool) = scheduler
tasks = map(A) do idcs
Expand All @@ -388,10 +385,10 @@ function _tmap(scheduler::DynamicScheduler{NoChunking},
map(fetch, tasks)
end

# w/o chunking (StaticScheduler{NoChunking}): AbstractChunksIterator
# w/o chunking (StaticScheduler{NoChunking}): AbstractChunks
function _tmap(scheduler::StaticScheduler{NoChunking},
f,
A::AbstractChunksIterator,
A::AbstractChunks,
_Arrs::AbstractArray...)
nt = nthreads()
tasks = map(enumerate(A)) do (c, idcs)
Expand Down

0 comments on commit c3e695b

Please sign in to comment.