From a548208666b36f3b769704685730823ba742c26c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 10 Oct 2024 13:38:57 +0200 Subject: [PATCH] Allow init argument for sum --- src/dispatch.jl | 4 ++-- src/reduce.jl | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index 2960906b..20aee65b 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -13,8 +13,8 @@ abstract type AbstractMutable end -function Base.sum(a::AbstractArray{<:AbstractMutable}) - return operate(sum, a) +function Base.sum(a::AbstractArray{<:AbstractMutable}; kws...) + return operate(sum, a; kws...) end # When doing `x'y` where the elements of `x` and/or `y` are arrays, redirecting diff --git a/src/reduce.jl b/src/reduce.jl index c1644470..9a909eee 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -50,11 +50,10 @@ function fused_map_reduce(op::F, args::Vararg{Any,N}) where {F<:Function,N} return accumulator end -function operate(::typeof(sum), a::AbstractArray) - return mapreduce( - identity, - add!!, - a; - init = zero(promote_operation(+, eltype(a), eltype(a))), - ) +function operate( + ::typeof(sum), + a::AbstractArray; + init = zero(promote_operation(+, eltype(a), eltype(a))), +) + return mapreduce(identity, add!!, a; init) end