diff --git a/src/implementation.jl b/src/implementation.jl index 0cb461dc..f7dc5a2f 100644 --- a/src/implementation.jl +++ b/src/implementation.jl @@ -14,17 +14,17 @@ using Base: @propagate_inbounds using Base.Threads: nthreads, @threads using BangBang: append!! using ChunkSplitters: ChunkSplitters, index_chunks, Consecutive -using ChunkSplitters.Internals: AbstractChunksIterator, IndexChunks +using ChunkSplitters.Internals: AbstractChunks, IndexChunks const MaybeScheduler = Union{NotGiven, Scheduler, Symbol} include("macro_impl.jl") -function auto_disable_chunking_warning() - @warn("You passed in a `<:AbstractChunksIterator` but also a scheduler that has "* - "chunking enabled. Will turn off internal chunking to proceed.\n"* - "To avoid this warning, turn off chunking (`chunking=false`).") -end +# function auto_disable_chunking_warning() +# @warn("You passed in a `<:AbstractChunks` but also a scheduler that has "* +# "chunking enabled. Will turn off internal chunking to proceed.\n"* +# "To avoid this warning, turn off chunking (`chunking=false`).") +# end function _index_chunks(sched, arg) C = chunking_mode(sched) @@ -103,15 +103,14 @@ function _tmapreduce(f, end end -# DynamicScheduler: AbstractChunksIterator +# DynamicScheduler: AbstractChunks function _tmapreduce(f, op, - Arrs::Union{Tuple{AbstractChunksIterator{T}}, Tuple{ChunkSplitters.Internals.Enumerate{T}}}, + Arrs::Union{Tuple{AbstractChunks{T}}, Tuple{ChunkSplitters.Internals.Enumerate{T}}}, ::Type{OutputType}, scheduler::DynamicScheduler, mapreduce_kwargs)::OutputType where {OutputType, T} (; threadpool) = scheduler - chunking_enabled(scheduler) && auto_disable_chunking_warning() tasks = map(only(Arrs)) do idcs @spawn threadpool promise_task_local(f)(idcs) end @@ -151,14 +150,13 @@ function _tmapreduce(f, end end -# StaticScheduler: AbstractChunksIterator +# StaticScheduler: AbstractChunks function _tmapreduce(f, op, - Arrs::Tuple{AbstractChunksIterator{T}}, # we don't support multiple chunks for now + Arrs::Tuple{AbstractChunks{T}}, # we don't support multiple chunks for now ::Type{OutputType}, scheduler::StaticScheduler, mapreduce_kwargs)::OutputType where {OutputType, T} - chunking_enabled(scheduler) && auto_disable_chunking_warning() check_all_have_same_indices(Arrs) chnks = only(Arrs) nt = nthreads() @@ -321,7 +319,7 @@ function tmap(f, ::Type{T}, A::AbstractArray, _Arrs::AbstractArray...; kwargs... end function tmap(f, - A::Union{AbstractArray, AbstractChunksIterator, ChunkSplitters.Internals.Enumerate}, + A::Union{AbstractArray, AbstractChunks, ChunkSplitters.Internals.Enumerate}, _Arrs::AbstractArray...; scheduler::MaybeScheduler = NotGiven(), kwargs...) @@ -334,9 +332,8 @@ function tmap(f, _scheduler.split != Consecutive() error("Only `split == Consecutive()` is supported because the parallel operation isn't commutative. (Scheduler: $_scheduler)") end - if (A isa AbstractChunksIterator || A isa ChunkSplitters.Internals.Enumerate) && + if (A isa AbstractChunks || A isa ChunkSplitters.Internals.Enumerate) && chunking_enabled(_scheduler) - auto_disable_chunking_warning() if _scheduler isa DynamicScheduler _scheduler = DynamicScheduler(; threadpool = _scheduler.threadpool, @@ -376,10 +373,10 @@ function _tmap(scheduler::DynamicScheduler{NoChunking}, reshape(v, size(A)...) end -# w/o chunking (DynamicScheduler{NoChunking}): AbstractChunksIterator +# w/o chunking (DynamicScheduler{NoChunking}): AbstractChunks function _tmap(scheduler::DynamicScheduler{NoChunking}, f, - A::Union{AbstractChunksIterator, ChunkSplitters.Internals.Enumerate}, + A::Union{AbstractChunks, ChunkSplitters.Internals.Enumerate}, _Arrs::AbstractArray...) (; threadpool) = scheduler tasks = map(A) do idcs @@ -388,10 +385,10 @@ function _tmap(scheduler::DynamicScheduler{NoChunking}, map(fetch, tasks) end -# w/o chunking (StaticScheduler{NoChunking}): AbstractChunksIterator +# w/o chunking (StaticScheduler{NoChunking}): AbstractChunks function _tmap(scheduler::StaticScheduler{NoChunking}, f, - A::AbstractChunksIterator, + A::AbstractChunks, _Arrs::AbstractArray...) nt = nthreads() tasks = map(enumerate(A)) do (c, idcs)