From 739bcc52d7c9a1a83dda8088f8df7af41e222006 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Fri, 21 Feb 2020 12:00:40 +0100 Subject: [PATCH] Fix promotion with adjoint and transpose --- src/Test/array.jl | 40 +++++++++++++++++++++++++++------------- src/linear_algebra.jl | 7 +++++++ test/dummy.jl | 4 ++-- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/Test/array.jl b/src/Test/array.jl index f3875b47..7ac60f60 100644 --- a/src/Test/array.jl +++ b/src/Test/array.jl @@ -8,21 +8,35 @@ function matrix_vector_division_test(x) end function _xAx_test(x::AbstractVector, A::AbstractMatrix) - @test_rewrite(x' * A) - # Complex expression - @test_rewrite(x' * ones(Int, size(A)...)) - @test_rewrite(x' * A * x) - # Complex expression - @test_rewrite(x' * ones(Int, size(A)...) * x) - @test_rewrite reshape(x, (1, length(x))) * A * x .- 1 - @test_rewrite x' * A * x .- 1 - @test_rewrite x' * A * x - 1 + for t in [transpose, adjoint] + @test_rewrite(t(x) * A) + # Complex expression + @test_rewrite(t(x) * ones(Int, size(A)...)) + @test_rewrite(t(x) * A * x) + # Complex expression + @test_rewrite(t(x) * ones(Int, size(A)...) * x) + @test_rewrite reshape(x, (1, length(x))) * A * x .- 1 + @test_rewrite t(x) * A * x .- 1 + @test_rewrite t(x) * A * x - 1 + @test_rewrite t(x) * x + t(x) * A * x + @test_rewrite t(x) * x - t(x) * A * x + @test MA.promote_operation(*, typeof(t(x)), typeof(A), typeof(x)) == typeof(t(x) * A * x) + @test MA.promote_operation(*, typeof(t(x)), typeof(x)) == typeof(t(x) * x) + @test_rewrite t(x) * x + 2 * t(x) * A * x + @test_rewrite t(x) * x - 2 * t(x) * A * x + @test_rewrite t(x) * A * x + 2 * t(x) * x + @test_rewrite t(x) * A * x - 2 * t(x) * x + @test MA.promote_operation(*, Int, typeof(t(x)), typeof(A), typeof(x)) == typeof(2 * t(x) * A * x) + @test MA.promote_operation(*, Int, typeof(t(x)), typeof(x)) == typeof(2 * t(x) * x) + end end function _xABx_test(x::AbstractVector, A::AbstractMatrix, B::AbstractMatrix) - @test_rewrite (x'A)' + 2B * x - @test_rewrite (x'A)' + 2B * x .- 1 - @test_rewrite (x'A)' + 2B * x .- [length(x):-1:1;] - @test_rewrite (x'A)' + 2B * x - [length(x):-1:1;] + for t in [transpose, adjoint] + @test_rewrite t(t(x) * A) + 2B * x + @test_rewrite t(t(x) * A) + 2B * x .- 1 + @test_rewrite t(t(x) * A) + 2B * x .- [length(x):-1:1;] + @test_rewrite t(t(x) * A) + 2B * x - [length(x):-1:1;] + end end function _matrix_vector_test(x::AbstractVector, A::AbstractMatrix) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index f7f82506..f3a54e63 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -233,10 +233,17 @@ end const TransposeOrAdjoint{T, MT} = Union{LinearAlgebra.Transpose{T, MT}, LinearAlgebra.Adjoint{T, MT}} _mirror_transpose_or_adjoint(x, ::LinearAlgebra.Transpose) = LinearAlgebra.transpose(x) _mirror_transpose_or_adjoint(x, ::LinearAlgebra.Adjoint) = LinearAlgebra.adjoint(x) +_mirror_transpose_or_adjoint(A::Type{<:AbstractArray{T}}, ::Type{<:LinearAlgebra.Transpose}) where {T} = LinearAlgebra.Transpose{T, A} +_mirror_transpose_or_adjoint(A::Type{<:AbstractArray{T}}, ::Type{<:LinearAlgebra.Adjoint}) where {T} = LinearAlgebra.Adjoint{T, A} +similar_array_type(TA::Type{<:TransposeOrAdjoint{T, A}}, ::Type{S}) where {S, T, A} = _mirror_transpose_or_adjoint(similar_array_type(A, S), TA) # dot product function promote_array_mul(::Type{<:TransposeOrAdjoint{S, <:AbstractVector}}, ::Type{<:AbstractVector{T}}) where {S, T} return promote_sum_mul(S, T) end +function promote_array_mul(A::Type{<:TransposeOrAdjoint{S, V}}, M::Type{<:AbstractMatrix{T}}) where {S, T, V <: AbstractVector} + B = promote_array_mul(_mirror_transpose_or_adjoint(M, A), V) + return _mirror_transpose_or_adjoint(B, A) +end function operate(::typeof(*), x::LinearAlgebra.Adjoint{<:Any, <:AbstractVector}, y::AbstractVector) return operate(LinearAlgebra.dot, parent(x), y) end diff --git a/test/dummy.jl b/test/dummy.jl index be23c03c..2bd90a07 100644 --- a/test/dummy.jl +++ b/test/dummy.jl @@ -30,8 +30,8 @@ MA.scaling(x::DummyBigInt) = x MA.mutable_operate_to!(x::DummyBigInt, op::Function, args::Union{MA.Scaling, DummyBigInt}...) = DummyBigInt(MA.mutable_operate_to!(x.data, op, _data.(args)...)) # Called for instance if `args` is `(v', v)` for a vector `v`. -MA.mutable_operate_to!(output::DummyBigInt, op::typeof(MA.add_mul), x::Union{MA.Scaling, DummyBigInt}, y::Union{MA.Scaling, DummyBigInt}, z::Union{MA.Scaling, DummyBigInt}, args::Union{MA.Scaling, DummyBigInt}...) = MA.mutable_operate_to!(output, +, x, *(y, z, args...)) -MA.mutable_operate_to!(output::DummyBigInt, op::typeof(MA.add_mul), x, y, z, args...) = MA.mutable_operate_to!(output, +, x, *(y, z, args...)) +MA.mutable_operate_to!(output::DummyBigInt, op::MA.AddSubMul, x::Union{MA.Scaling, DummyBigInt}, y::Union{MA.Scaling, DummyBigInt}, z::Union{MA.Scaling, DummyBigInt}, args::Union{MA.Scaling, DummyBigInt}...) = MA.mutable_operate_to!(output, MA.add_sub_op(op), x, *(y, z, args...)) +MA.mutable_operate_to!(output::DummyBigInt, op::MA.AddSubMul, x, y, z, args...) = MA.mutable_operate_to!(output, MA.add_sub_op(op), x, *(y, z, args...)) function MA.mutable_operate!(op::Function, x::DummyBigInt, args::Vararg{Any, N}) where N MA.mutable_operate_to!(x, op, x, args...) end