From 229103aaca799bcd3d04fe44ce56c35ceab521c9 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Thu, 30 Nov 2023 14:35:10 +1300 Subject: [PATCH] Add dispatch for LinearAlgebra.dot for Symmetric and Hermitian matrices (#248) Co-authored-by: araujoms --- src/dispatch.jl | 26 +++++++------------------- test/dummy.jl | 2 ++ test/rewrite.jl | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 19 deletions(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index a4a1be50..c1f4cf43 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -51,25 +51,13 @@ function LinearAlgebra._dot_nonrecursive( return fused_map_reduce(add_mul, lhs, rhs) end -function LinearAlgebra.dot( - lhs::AbstractArray{<:AbstractMutable}, - rhs::AbstractArray, -) - return operate(LinearAlgebra.dot, lhs, rhs) -end - -function LinearAlgebra.dot( - lhs::AbstractArray, - rhs::AbstractArray{<:AbstractMutable}, -) - return operate(LinearAlgebra.dot, lhs, rhs) -end - -function LinearAlgebra.dot( - lhs::AbstractArray{<:AbstractMutable}, - rhs::AbstractArray{<:AbstractMutable}, -) - return operate(LinearAlgebra.dot, lhs, rhs) +for A in (LinearAlgebra.Symmetric, LinearAlgebra.Hermitian, AbstractArray) + B = A{<:AbstractMutable} + @eval begin + LinearAlgebra.dot(x::$A, y::$B) = operate(LinearAlgebra.dot, x, y) + LinearAlgebra.dot(x::$B, y::$A) = operate(LinearAlgebra.dot, x, y) + LinearAlgebra.dot(x::$B, y::$B) = operate(LinearAlgebra.dot, x, y) + end end # Special-case because the the base version wants to do diff --git a/test/dummy.jl b/test/dummy.jl index 0ed4082e..c84e55bb 100644 --- a/test/dummy.jl +++ b/test/dummy.jl @@ -31,6 +31,8 @@ Base.copy(x::DummyBigInt) = x MA.mutable_copy(x::DummyBigInt) = DummyBigInt(MA.mutable_copy(x.data)) LinearAlgebra.symmetric_type(::Type{DummyBigInt}) = DummyBigInt LinearAlgebra.symmetric(x::DummyBigInt, ::Symbol) = x +LinearAlgebra.hermitian_type(::Type{DummyBigInt}) = DummyBigInt +LinearAlgebra.hermitian(x::DummyBigInt, ::Symbol) = x LinearAlgebra.dot(x::DummyBigInt, y::DummyBigInt) = x * y function LinearAlgebra.dot( x::DummyBigInt, diff --git a/test/rewrite.jl b/test/rewrite.jl index dcdb32e2..6b3419df 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -168,3 +168,36 @@ Base.getindex(x::_KwargRef; i) = x.data[i] x = _KwargRef(Dict(i => i + 1 for i in 2:4)) @test MA.@rewrite(sum(x[i = j] for j in 2:4)) == 12 end + +@testset "dispatch_dot" begin + # Symmetric + x = DummyBigInt[1 2; 2 3] + y = LinearAlgebra.Symmetric(x) + @test MA.isequal_canonical( + LinearAlgebra.dot(x, y), + MA.operate(LinearAlgebra.dot, x, y), + ) + a = @allocated LinearAlgebra.dot(x, y) + b = @allocated MA.operate(LinearAlgebra.dot, x, y) + @test a == b + # Symmetric + x = DummyBigInt[1 2; 2 3] + y = LinearAlgebra.Hermitian(x) + @test MA.isequal_canonical( + LinearAlgebra.dot(x, y), + MA.operate(LinearAlgebra.dot, x, y), + ) + a = @allocated LinearAlgebra.dot(x, y) + b = @allocated MA.operate(LinearAlgebra.dot, x, y) + @test a == b + # AbstractArray + x = DummyBigInt[1 2; 2 3] + y = x + @test MA.isequal_canonical( + LinearAlgebra.dot(x, y), + MA.operate(LinearAlgebra.dot, x, y), + ) + a = @allocated LinearAlgebra.dot(x, y) + b = @allocated MA.operate(LinearAlgebra.dot, x, y) + @test a == b +end