diff --git a/src/dispatch.jl b/src/dispatch.jl index a5a1c28..10dbc6d 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -13,8 +13,15 @@ abstract type AbstractMutable end -function Base.sum(a::AbstractArray{<:AbstractMutable}; kwargs...) - return operate(sum, a; kwargs...) +function Base.sum( + a::AbstractArray{T}; + dims = missing, + init = zero(promote_operation(+, T, T)), +) where {T<:AbstractMutable} + if !ismissing(dims) + return mapreduce(identity, Base.add_sum, a; dims, init) + end + return operate(sum, a; init) end # When doing `x'y` where the elements of `x` and/or `y` are arrays, redirecting diff --git a/test/dispatch.jl b/test/dispatch.jl index 722f0a0..5a8b193 100644 --- a/test/dispatch.jl +++ b/test/dispatch.jl @@ -131,3 +131,15 @@ end # MA is at least 10-times better than no MA for this example @test 10 * with_init < no_ma end + +@testset "sum_with_init_and_dims" begin + x = reshape(convert(Vector{DummyBigInt}, 1:12), 3, 4) + X = reshape(1:12, 3, 4) + for dims in (1, 2, :, 1:2, (1, 2)) + # Without (; init) + @test MA.isequal_canonical(sum(x; dims), DummyBigInt.(sum(X; dims))) + # With (; init) + y = sum(x; init = DummyBigInt(0), dims) + @test MA.isequal_canonical(y, DummyBigInt.(sum(X; dims))) + end +end