From 85f5ff262ea7927fd6c418b287068a2cecdf244a Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Tue, 9 Jan 2024 21:29:51 +1300 Subject: [PATCH] Fix implementation of mul! for AbstractMatrix and AbstractVector (#252) --- src/implementations/LinearAlgebra.jl | 93 +++++++++------------------- test/matmul.jl | 2 + 2 files changed, 30 insertions(+), 65 deletions(-) diff --git a/src/implementations/LinearAlgebra.jl b/src/implementations/LinearAlgebra.jl index 5415d22e..12739314 100644 --- a/src/implementations/LinearAlgebra.jl +++ b/src/implementations/LinearAlgebra.jl @@ -223,41 +223,6 @@ function promote_array_mul( return Vector{promote_sum_mul(S, T)} end -################################################################################ -# We roll our own matmul here (instead of using Julia's generic fallbacks) -# because doing so allows us to accumulate the expressions for the inner loops -# in-place. -# Additionally, Julia's generic fallbacks can be finnicky when your array -# elements aren't `<:Number`. - -# This method of `mul!` is adapted from upstream Julia. Note that we -# confuse transpose with adjoint. -#= -> Copyright (c) 2009-2018: Jeff Bezanson, Stefan Karpinski, Viral B. Shah, -> and other contributors: -> -> https://github.com/JuliaLang/julia/contributors -> -> Permission is hereby granted, free of charge, to any person obtaining -> a copy of this software and associated documentation files (the -> "Software"), to deal in the Software without restriction, including -> without limitation the rights to use, copy, modify, merge, publish, -> distribute, sublicense, and/or sell copies of the Software, and to -> permit persons to whom the Software is furnished to do so, subject to -> the following conditions: -> -> The above copyright notice and this permission notice shall be -> included in all copies or substantial portions of the Software. -> -> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -> EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -> MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -> NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -> LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -> OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -> WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -=# - function _dim_check(C::AbstractVector, A::AbstractMatrix, B::AbstractVector) mB = length(B) mA, nA = size(A) @@ -298,46 +263,44 @@ function _dim_check(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) return end -function _add_mul_array(buffer, C::Vector, A::AbstractMatrix, B::AbstractVector) - Astride = size(A, 1) - # We need a buffer to hold the intermediate multiplication. - @inbounds begin - for k in eachindex(B) - aoffs = (k - 1) * Astride - b = B[k] - for i in Base.OneTo(size(A, 1)) - C[i] = buffered_operate!!(buffer, add_mul, C[i], A[aoffs+i], b) - end - end - end # @inbounds - return C -end - -# This is incorrect if `C` is `LinearAlgebra.Symmetric` as we modify twice the -# same diagonal element. -function _add_mul_array(buffer, C::Matrix, A::AbstractMatrix, B::AbstractMatrix) - @inbounds begin - for i in 1:size(A, 1), j in 1:size(B, 2) - Ctmp = C[i, j] - for k in 1:size(A, 2) - Ctmp = - buffered_operate!!(buffer, add_mul, Ctmp, A[i, k], B[k, j]) - end - C[i, j] = Ctmp +function buffered_operate!( + buffer, + ::typeof(add_mul), + C::Vector, + A::AbstractMatrix, + B::AbstractVector, +) + _dim_check(C, A, B) + for (ci, ai) in zip(axes(C, 1), axes(A, 1)) + for (aj, bj) in zip(axes(A, 2), axes(B, 1)) + C[ci] = buffered_operate!!(buffer, add_mul, C[ci], A[ai, aj], B[bj]) end - end # @inbounds + end return C end function buffered_operate!( buffer, ::typeof(add_mul), - C::VecOrMat, + C::Matrix, A::AbstractMatrix, - B::AbstractVecOrMat, + B::AbstractMatrix, ) _dim_check(C, A, B) - return _add_mul_array(buffer, C, A, B) + for (ci, ai) in zip(axes(C, 1), axes(A, 1)) + for (cj, bj) in zip(axes(C, 2), axes(B, 2)) + for (aj, bi) in zip(axes(A, 2), axes(B, 1)) + C[ci, cj] = buffered_operate!!( + buffer, + add_mul, + C[ci, cj], + A[ai, aj], + B[bi, bj], + ) + end + end + end + return C end function buffer_for( diff --git a/test/matmul.jl b/test/matmul.jl index 0fd968e9..2c26a51f 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -304,6 +304,8 @@ Base.size(x::Issue65Matrix) = size(x.x) Base.getindex(x::Issue65Matrix, args...) = getindex(x.x, args...) Base.axes(x::Issue65Matrix, n) = Issue65OneTo(size(x.x, n)) Base.convert(::Type{Base.OneTo}, x::Issue65OneTo) = Base.OneTo(x.N) +Base.iterate(x::Issue65OneTo) = iterate(Base.OneTo(x.N)) +Base.iterate(x::Issue65OneTo, arg) = iterate(Base.OneTo(x.N), arg) @testset "Issue #65" begin x = [1.0 2.0; 3.0 4.0]