From a60f111b9df59b95ec8da1761c38ca78d16fe388 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 28 Jun 2023 17:42:13 +0200 Subject: [PATCH] operate! with add_dot --- src/implementations/BigInt.jl | 2 +- src/interface.jl | 2 +- src/reduce.jl | 17 +++++++++++------ 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/implementations/BigInt.jl b/src/implementations/BigInt.jl index 36aa8c45..c3b5dc4c 100644 --- a/src/implementations/BigInt.jl +++ b/src/implementations/BigInt.jl @@ -95,7 +95,7 @@ function operate_to!( return operate!(op, output, c...) end -function operate!(op::Function, x::BigInt, args::Vararg{Any,N}) where {N} +function operate_fallback!(op::Function, x::BigInt, args::Vararg{Any,N}) where {N} return operate_to!(x, op, x, args...) end diff --git a/src/interface.jl b/src/interface.jl index 5b3452ee..40e088c4 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -686,7 +686,7 @@ function promote_operation_fallback( ::Type{<:AbstractArray{A}}, ::Type{<:AbstractArray{B}}, ) where {A,B} - C = promote_operation(*, A, B) + C = promote_operation(LinearAlgebra.dot, A, B) return promote_operation(+, C, C) end diff --git a/src/reduce.jl b/src/reduce.jl index c1644470..3067545a 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -38,16 +38,21 @@ end _concrete_eltype(x) = isempty(x) ? eltype(x) : typeof(first(x)) -function fused_map_reduce(op::F, args::Vararg{Any,N}) where {F<:Function,N} - _check_same_length(args...) + +function operate!(op::typeof(add_dot), output, args::Vararg{Any,N}) where {N} T = promote_map_reduce(op, _concrete_eltype.(args)...) - accumulator = neutral_element(reduce_op(op), T) buffer = buffer_for(op, T, eltype.(args)...) for I in zip(eachindex.(args)...) - accumulator = - buffered_operate!!(buffer, op, accumulator, getindex.(args, I)...) + output = buffered_operate!!(buffer, op, output, getindex.(args, I)...) end - return accumulator + return output +end + +function fused_map_reduce(op::F, args::Vararg{Any,N}) where {F<:Function,N} + _check_same_length(args...) + T = promote_map_reduce(op, _concrete_eltype.(args)...) + accumulator = neutral_element(reduce_op(op), T) + return operate!(op, accumulator, args...) end function operate(::typeof(sum), a::AbstractArray)