Skip to content

Commit

Permalink
Fix implementation of mul! for AbstractMatrix and AbstractVector
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Dec 21, 2023
1 parent 69a675d commit 5d1541f
Showing 1 changed file with 28 additions and 30 deletions.
58 changes: 28 additions & 30 deletions src/implementations/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,46 +298,44 @@ function _dim_check(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
return
end

function _add_mul_array(buffer, C::Vector, A::AbstractMatrix, B::AbstractVector)
Astride = size(A, 1)
# We need a buffer to hold the intermediate multiplication.
@inbounds begin
for k in eachindex(B)
aoffs = (k - 1) * Astride
b = B[k]
for i in Base.OneTo(size(A, 1))
C[i] = buffered_operate!!(buffer, add_mul, C[i], A[aoffs+i], b)
end
end
end # @inbounds
return C
end

# This is incorrect if `C` is `LinearAlgebra.Symmetric` as we modify twice the
# same diagonal element.
function _add_mul_array(buffer, C::Matrix, A::AbstractMatrix, B::AbstractMatrix)
@inbounds begin
for i in 1:size(A, 1), j in 1:size(B, 2)
Ctmp = C[i, j]
for k in 1:size(A, 2)
Ctmp =
buffered_operate!!(buffer, add_mul, Ctmp, A[i, k], B[k, j])
end
C[i, j] = Ctmp
function buffered_operate!(
buffer,
::typeof(add_mul),
C::Vector,
A::AbstractMatrix,
B::AbstractVector,
)
_dim_check(C, A, B)
for (ci, ai) in zip(axes(C, 1), axes(A, 1))
for (aj, bj) in zip(axes(A, 2), axes(B, 1))
C[ci] = buffered_operate!!(buffer, add_mul, C[ci], A[ai, aj], B[bj])
end
end # @inbounds
end
return C
end

function buffered_operate!(
buffer,
::typeof(add_mul),
C::VecOrMat,
C::Matrix,
A::AbstractMatrix,
B::AbstractVecOrMat,
B::AbstractMatrix,
)
_dim_check(C, A, B)
return _add_mul_array(buffer, C, A, B)
for (ci, ai) in zip(axes(C, 1), axes(A, 1))
for (cj, bj) in zip(axes(C, 2), axes(B, 2))
for (aj, bi) in zip(axes(A, 2), axes(B, 1))
C[ci, cj] = buffered_operate!!(
buffer,
add_mul,
C[ci, cj],
A[ai, aj],
B[bi, bj],
)
end
end
end
return C
end

function buffer_for(
Expand Down

0 comments on commit 5d1541f

Please sign in to comment.