From 381a59d320d9f23b5b0698f319dc69e3cd18e72f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Beno=C3=AEt=20Legat?=
Date: Thu, 10 Oct 2024 23:13:53 +0200
Subject: [PATCH] Allow init argument for sum (#306)

---
Notes (reviewer commentary placed between `---` and the first diff; not part
of the commit message):
  * src/dispatch.jl: `Base.sum` on `AbstractArray{<:AbstractMutable}` now
    accepts keyword arguments and forwards them to `operate(sum, a; kwargs...)`.
  * src/reduce.jl: `operate(::typeof(sum), a)` takes `init` as a keyword; its
    default is the expression that was previously hard-coded in the
    `mapreduce` call, and `init` is passed straight to `mapreduce`.
  * test/dispatch.jl: adds a testset comparing allocations of `sum(x)`,
    `sum(x; init = ...)` and a plain `+=` loop.
  * NOTE(review): the line structure of this patch was restored from a
    collapsed copy. Hunk line counts match the `@@` headers, but the placement
    of blank context lines in the first hunk and the indentation of the three
    `end` context lines in test/dispatch.jl are reconstructed - confirm
    against the repository before applying.
  * NOTE(review): the `From:` header carries no e-mail address; `git am` may
    refuse it - confirm and restore the address if needed.

 src/dispatch.jl  |  4 ++--
 src/reduce.jl    | 13 ++++++-------
 test/dispatch.jl | 25 +++++++++++++++++++++++++
 3 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/src/dispatch.jl b/src/dispatch.jl
index 2960906b..a5a1c28f 100644
--- a/src/dispatch.jl
+++ b/src/dispatch.jl
@@ -13,8 +13,8 @@
 
 abstract type AbstractMutable end
 
-function Base.sum(a::AbstractArray{<:AbstractMutable})
-    return operate(sum, a)
+function Base.sum(a::AbstractArray{<:AbstractMutable}; kwargs...)
+    return operate(sum, a; kwargs...)
 end
 
 # When doing `x'y` where the elements of `x` and/or `y` are arrays, redirecting
diff --git a/src/reduce.jl b/src/reduce.jl
index c1644470..9a909eee 100644
--- a/src/reduce.jl
+++ b/src/reduce.jl
@@ -50,11 +50,10 @@ function fused_map_reduce(op::F, args::Vararg{Any,N}) where {F<:Function,N}
     return accumulator
 end
 
-function operate(::typeof(sum), a::AbstractArray)
-    return mapreduce(
-        identity,
-        add!!,
-        a;
-        init = zero(promote_operation(+, eltype(a), eltype(a))),
-    )
+function operate(
+    ::typeof(sum),
+    a::AbstractArray;
+    init = zero(promote_operation(+, eltype(a), eltype(a))),
+)
+    return mapreduce(identity, add!!, a; init)
 end
diff --git a/test/dispatch.jl b/test/dispatch.jl
index 84ba729d..722f0a0a 100644
--- a/test/dispatch.jl
+++ b/test/dispatch.jl
@@ -106,3 +106,28 @@ end
         end
     end
 end
+
+function non_mutable_sum_pr306(x)
+    y = zero(eltype(x))
+    for xi in x
+        y += xi
+    end
+    return y
+end
+
+@testset "sum_with_init" begin
+    x = convert(Vector{DummyBigInt}, 1:100)
+    # compilation
+    @allocated sum(x)
+    @allocated sum(x; init = DummyBigInt(0))
+    @allocated non_mutable_sum_pr306(x)
+    # now test actual allocations
+    no_init = @allocated sum(x)
+    with_init = @allocated sum(x; init = DummyBigInt(0))
+    no_ma = @allocated non_mutable_sum_pr306(x)
+    # There's an additional 16 bytes for kwarg version. Upper bound by 40 to be
+    # safe between Julia versions
+    @test with_init <= no_init + 40
+    # MA is at least 10-times better than no MA for this example
+    @test 10 * with_init < no_ma
+end