From 9851148768f34b6c209c7e0c2d7f2b8b2cf83648 Mon Sep 17 00:00:00 2001 From: odow Date: Tue, 15 Oct 2024 14:28:40 +1300 Subject: [PATCH 1/4] Fix sum(::AbstractArray{<:AbstractMutable}; dims) --- src/dispatch.jl | 11 +++++++++-- test/dispatch.jl | 12 ++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index a5a1c28..10dbc6d 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -13,8 +13,15 @@ abstract type AbstractMutable end -function Base.sum(a::AbstractArray{<:AbstractMutable}; kwargs...) - return operate(sum, a; kwargs...) +function Base.sum( + a::AbstractArray{T}; + dims = missing, + init = zero(promote_operation(+, T, T)), +) where {T<:AbstractMutable} + if !ismissing(dims) + return mapreduce(identity, Base.add_sum, a; dims, init) + end + return operate(sum, a; init) end # When doing `x'y` where the elements of `x` and/or `y` are arrays, redirecting diff --git a/test/dispatch.jl b/test/dispatch.jl index 722f0a0..5a8b193 100644 --- a/test/dispatch.jl +++ b/test/dispatch.jl @@ -131,3 +131,15 @@ end # MA is at least 10-times better than no MA for this example @test 10 * with_init < no_ma end + +@testset "sum_with_init_and_dims" begin + x = reshape(convert(Vector{DummyBigInt}, 1:12), 3, 4) + X = reshape(1:12, 3, 4) + for dims in (1, 2, :, 1:2, (1, 2)) + # Without (; init) + @test MA.isequal_canonical(sum(x; dims), DummyBigInt.(sum(X; dims))) + # With (; init) + y = sum(x; init = DummyBigInt(0), dims) + @test MA.isequal_canonical(y, DummyBigInt.(sum(X; dims))) + end +end From 42fc75d5f5273db041c1302beceaf7b892f49e4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 15 Oct 2024 16:19:04 +0200 Subject: [PATCH 2/4] Add comments --- src/dispatch.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index 10dbc6d..1db2dac 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -15,10 +15,13 @@ abstract type AbstractMutable end function Base.sum( a::AbstractArray{T}; - dims = missing, + dims = :, init = zero(promote_operation(+, T, T)), ) where {T<:AbstractMutable} if !ismissing(dims) + # We cannot use `mapreduce` with `add!!` instead of `Base.add_mul` like + # `operate(sum, ...)` because the same instance given at `init` is used + # at several places. return mapreduce(identity, Base.add_sum, a; dims, init) end return operate(sum, a; init) From f3695dd104c1f7819ea1e58af37e00daab4d1ebd Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Wed, 16 Oct 2024 10:33:09 +1300 Subject: [PATCH 3/4] Update src/dispatch.jl --- src/dispatch.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index 1db2dac..46e6860 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -18,7 +18,7 @@ function Base.sum( dims = :, init = zero(promote_operation(+, T, T)), ) where {T<:AbstractMutable} - if !ismissing(dims) + if dims !== : # We cannot use `mapreduce` with `add!!` instead of `Base.add_mul` like # `operate(sum, ...)` because the same instance given at `init` is used # at several places. From f4bdeb36541f42a7f7994b8ef1daed84d4e71aa2 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Wed, 16 Oct 2024 10:48:26 +1300 Subject: [PATCH 4/4] Update src/dispatch.jl --- src/dispatch.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index 46e6860..43c1fb4 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -18,7 +18,7 @@ function Base.sum( dims = :, init = zero(promote_operation(+, T, T)), ) where {T<:AbstractMutable} - if dims !== : + if dims !== Colon() # We cannot use `mapreduce` with `add!!` instead of `Base.add_mul` like # `operate(sum, ...)` because the same instance given at `init` is used # at several places.