From 69b1c51585aa859fc6b36d44e5ef4fec0edf1a6b Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Wed, 27 Mar 2024 08:52:25 +1300 Subject: [PATCH] Fix mutability check for array multiplication (#272) --- src/interface.jl | 31 +++++++++++++++++++++++++++++++ test/interface.jl | 19 +++++++++++++++++++ test/matmul.jl | 27 ++++++++++++++++++++------- 3 files changed, 70 insertions(+), 7 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 74f94903..33dd132c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -281,6 +281,37 @@ function mutability(x, op, args::Vararg{Any,N}) where {N} return mutability(typeof(x), op, typeof.(args)...) end +# As a special case, we need to check that the shapes of matrix multiplication +# match in order for the array to be mutable. +function mutability( + x::AbstractArray, + op::typeof(*), + args::Vararg{Any,N}, +) where {N} + is_mutable = mutability(typeof(x), op, typeof.(args)...) == IsMutable() + if is_mutable && size(x) == _size_after_multiply(size.(args)...) + return IsMutable() + end + return IsNotMutable() +end + +function _size_after_multiply(x::NTuple{M,Int}, y::NTuple{N,Int}) where {N,M} + if x[end] != y[1] + return nothing + end + return (x[1:end-1]..., y[2:end]...) +end + +_size_after_multiply(::Tuple{}, rhs::NTuple{M,Int}) where {M} = rhs +_size_after_multiply(lhs::NTuple{M,Int}, ::Tuple{}) where {M} = lhs +_size_after_multiply(::Tuple{}, ::Tuple{}) = () +_size_after_multiply(::Nothing, ::Any) = nothing + +function _size_after_multiply(x::NTuple{M,Int}, y::Vararg{Any,N}) where {N,M} + head = _size_after_multiply(x, Base.first(y)) + return _size_after_multiply(head, Base.tail(y)...) +end + mutability(::Type) = IsNotMutable() """ diff --git a/test/interface.jl b/test/interface.jl index aed08ecd..bf75d10f 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -187,3 +187,22 @@ end end end end + +@testset "issue_271_mutability" begin + a = 1 + x = [1; 2;;] + y = [1 2; 3 4] + z = [1 2; 3 4; 5 6] + @test MA.mutability(x, *, x, x') == MA.IsNotMutable() + @test MA.mutability(x, *, x', x) == MA.IsNotMutable() + @test MA.mutability(x, *, x, a, x') == MA.IsNotMutable() + @test MA.mutability(x, *, x', a, x) == MA.IsNotMutable() + @test MA.mutability(y, *, y, y) == MA.IsMutable() + @test MA.mutability(y, *, y, y, y) == MA.IsMutable() + @test MA.mutability(y, *, y, a, y') == MA.IsMutable() + @test MA.mutability(y, *, y', a, y) == MA.IsMutable() + @test MA.mutability(y, *, a, a, y) == MA.IsMutable() + @test MA.mutability(y, *, y, z', z) == MA.IsMutable() + @test MA.mutability(z, *, z, z) == MA.IsNotMutable() + @test MA.mutability(z, *, z, z, y) == MA.IsNotMutable() +end diff --git a/test/matmul.jl b/test/matmul.jl index 339b87e1..029edccb 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -142,11 +142,11 @@ end @test MA.mul(A, x) == BigInt[3; 3; 3] @test MA.mul_to!!(y, A, x) == BigInt[3; 3; 3] && y == BigInt[3; 3; 3] @test_throws DimensionMismatch MA.mul(BigInt[1 1; 1 1], BigInt[]) - @test_throws DimensionMismatch MA.mul_to!!( - BigInt[], - BigInt[1 1; 1 1], - BigInt[1; 1], - ) + @test MA.mul_to!!(BigInt[], BigInt[1 1; 1 1], BigInt[1; 1]) == + BigInt[2, 2] + z = BigInt[0, 0] + @test MA.mul_to!!(z, BigInt[1 1; 1 1], BigInt[1; 1]) === z + @test z == BigInt[2, 2] @testset "mutability" begin alloc_test(() -> MA.promote_operation(*, typeof(A), typeof(x)), 0) @@ -219,11 +219,11 @@ end BigInt[1 1; 1 1], zeros(BigInt, 1, 1), ) - @test_throws DimensionMismatch MA.mul_to!!( + @test MA.mul_to!!( zeros(BigInt, 1, 1), BigInt[1 1; 1 1], zeros(BigInt, 2, 1), - ) + ) == zeros(BigInt, 2, 1) @testset "mutability" begin alloc_test(() -> MA.promote_operation(*, typeof(A), typeof(B)), 0) @@ -422,3 +422,16 @@ Base.:*(m::Monomial, ::Monomial) = m @test T == typeof(MA.operate(*, a, b)) end end + +@testset "Issue_271" begin + A = reshape([1, 2], (2, 1)) + B = [1 2] + C = MA.operate!!(*, A, B) + @test A == reshape([1, 2], (2, 1)) + @test B == [1 2] + @test C == A * B + D = MA.operate!!(*, B, A) + @test A == reshape([1, 2], (2, 1)) + @test B == [1 2] + @test D == B * A +end