From 2481775ea0adcd22689f979e3a0644f89d6df4db Mon Sep 17 00:00:00 2001 From: odow Date: Mon, 25 Mar 2024 10:36:40 +1300 Subject: [PATCH 1/6] Fix mutability check for array multiplication --- src/interface.jl | 33 +++++++++++++++++++++++++++++++++ test/interface.jl | 19 +++++++++++++++++++ test/matmul.jl | 13 +++++++++++++ 3 files changed, 65 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index 74f9490..eaddbfa 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -281,6 +281,39 @@ function mutability(x, op, args::Vararg{Any,N}) where {N} return mutability(typeof(x), op, typeof.(args)...) end +# As a special case, we need to check that the shapes of matrix multiplication +# match in order for the array to be mutable. +function mutability( + x::AbstractArray, + op::typeof(*), + args::Vararg{Any,N}, +) where {N} + if mutability(typeof(x), op, typeof.(args)...) == IsNotMutable() + return IsNotMutable() + elseif size(x) == _size_after_multiply(size.(args)...) + return IsMutable() + else + return IsNotMutable() + end +end + +function _size_after_multiply(x::NTuple{M,Int}, y::NTuple{N,Int}) where {N,M} + if x[end] != y[1] + return nothing + end + return (x[1:end-1]..., y[2:end]...) +end + +_size_after_multiply(::Tuple{}, rhs::NTuple{M,Int}) where {M} = rhs +_size_after_multiply(lhs::NTuple{M,Int}, ::Tuple{}) where {M} = lhs +_size_after_multiply(::Tuple{}, ::Tuple{}) = () +_size_after_multiply(::Nothing, ::Any) = nothing + +function _size_after_multiply(x::NTuple{M,Int}, y::Vararg{Any,N}) where {N,M} + head = _size_after_multiply(x, Base.first(y)) + return _size_after_multiply(head, Base.tail(y)...) +end + mutability(::Type) = IsNotMutable() """ diff --git a/test/interface.jl b/test/interface.jl index aed08ec..bf75d10 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -187,3 +187,22 @@ end end end end + +@testset "issue_271_mutability" begin + a = 1 + x = [1; 2;;] + y = [1 2; 3 4] + z = [1 2; 3 4; 5 6] + @test MA.mutability(x, *, x, x') == MA.IsNotMutable() + @test MA.mutability(x, *, x', x) == MA.IsNotMutable() + @test MA.mutability(x, *, x, a, x') == MA.IsNotMutable() + @test MA.mutability(x, *, x', a, x) == MA.IsNotMutable() + @test MA.mutability(y, *, y, y) == MA.IsMutable() + @test MA.mutability(y, *, y, y, y) == MA.IsMutable() + @test MA.mutability(y, *, y, a, y') == MA.IsMutable() + @test MA.mutability(y, *, y', a, y) == MA.IsMutable() + @test MA.mutability(y, *, a, a, y) == MA.IsMutable() + @test MA.mutability(y, *, y, z', z) == MA.IsMutable() + @test MA.mutability(z, *, z, z) == MA.IsNotMutable() + @test MA.mutability(z, *, z, z, y) == MA.IsNotMutable() +end diff --git a/test/matmul.jl b/test/matmul.jl index 339b87e..9a27663 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -422,3 +422,16 @@ Base.:*(m::Monomial, ::Monomial) = m @test T == typeof(MA.operate(*, a, b)) end end + +@testset "Issue_271" begin + A = [1; 2;;] + B = [1 2] + C = MA.operate!!(*, A, B) + @test A == [1; 2;;] + @test B == [1 2] + @test C == AB + D = MA.operate!!(*, B, A) + @test A == [1; 2;;] + @test B == [1 2] + @test D == BA +end From 55cc9a95479882385a3eeda89fad68ec6281eaa9 Mon Sep 17 00:00:00 2001 From: odow Date: Mon, 25 Mar 2024 10:59:45 +1300 Subject: [PATCH 2/6] Update --- test/matmul.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/matmul.jl b/test/matmul.jl index 9a27663..92cf061 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -142,11 +142,11 @@ end @test MA.mul(A, x) == BigInt[3; 3; 3] @test MA.mul_to!!(y, A, x) == BigInt[3; 3; 3] && y == BigInt[3; 3; 3] @test_throws DimensionMismatch MA.mul(BigInt[1 1; 1 1], BigInt[]) - @test_throws DimensionMismatch MA.mul_to!!( - BigInt[], - BigInt[1 1; 1 1], - BigInt[1; 1], - ) + @test MA.mul_to!!(BigInt[], BigInt[1 1; 1 1], BigInt[1; 1]) == + BigInt[2, 2] + z = BigInt[0, 0] + @test MA.mul_to!!(z, BigInt[1 1; 1 1], BigInt[1; 1]) === z + @test z == BigInt[2, 2] @testset "mutability" begin alloc_test(() -> MA.promote_operation(*, typeof(A), typeof(x)), 0) From 046aba9afbf5534854d2efb7ee80337ccbc659bd Mon Sep 17 00:00:00 2001 From: odow Date: Mon, 25 Mar 2024 11:04:29 +1300 Subject: [PATCH 3/6] Fix --- test/matmul.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/matmul.jl b/test/matmul.jl index 92cf061..86eb619 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -429,9 +429,9 @@ end C = MA.operate!!(*, A, B) @test A == [1; 2;;] @test B == [1 2] - @test C == AB + @test C == A * B D = MA.operate!!(*, B, A) @test A == [1; 2;;] @test B == [1 2] - @test D == BA + @test D == B * A end From 07da160de66426853e53984371f7980a66cdab58 Mon Sep 17 00:00:00 2001 From: odow Date: Mon, 25 Mar 2024 11:12:48 +1300 Subject: [PATCH 4/6] Update --- test/matmul.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/matmul.jl b/test/matmul.jl index 86eb619..4993264 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -219,11 +219,11 @@ end BigInt[1 1; 1 1], zeros(BigInt, 1, 1), ) - @test_throws DimensionMismatch MA.mul_to!!( + @test MA.mul_to!!( zeros(BigInt, 1, 1), BigInt[1 1; 1 1], zeros(BigInt, 2, 1), - ) + ) == zeros(BigInt, 2, 1) @testset "mutability" begin alloc_test(() -> MA.promote_operation(*, typeof(A), typeof(B)), 0) From 62c8f4d96f7740356d8f6f3bbff2fd783e8dc687 Mon Sep 17 00:00:00 2001 From: odow Date: Mon, 25 Mar 2024 11:37:16 +1300 Subject: [PATCH 5/6] Update --- src/interface.jl | 8 +++----- test/matmul.jl | 6 +++--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index eaddbfa..c515949 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -288,13 +288,11 @@ function mutability( op::typeof(*), args::Vararg{Any,N}, ) where {N} - if mutability(typeof(x), op, typeof.(args)...) == IsNotMutable() - return IsNotMutable() - elseif size(x) == _size_after_multiply(size.(args)...) + is_mutable = mutability(typeof(x), op, typeof.(args)...) + if is_mutable && size(x) == _size_after_multiply(size.(args)...) return IsMutable() - else - return IsNotMutable() end + return IsNotMutable() end function _size_after_multiply(x::NTuple{M,Int}, y::NTuple{N,Int}) where {N,M} diff --git a/test/matmul.jl b/test/matmul.jl index 4993264..029edcc 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -424,14 +424,14 @@ Base.:*(m::Monomial, ::Monomial) = m end @testset "Issue_271" begin - A = [1; 2;;] + A = reshape([1, 2], (2, 1)) B = [1 2] C = MA.operate!!(*, A, B) - @test A == [1; 2;;] + @test A == reshape([1, 2], (2, 1)) @test B == [1 2] @test C == A * B D = MA.operate!!(*, B, A) - @test A == [1; 2;;] + @test A == reshape([1, 2], (2, 1)) @test B == [1 2] @test D == B * A end From f00bd50121827cb70b034c83cfd5095a28d1d319 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Mon, 25 Mar 2024 12:38:09 +1300 Subject: [PATCH 6/6] Update src/interface.jl --- src/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index c515949..33dd132 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -288,7 +288,7 @@ function mutability( op::typeof(*), args::Vararg{Any,N}, ) where {N} - is_mutable = mutability(typeof(x), op, typeof.(args)...) + is_mutable = mutability(typeof(x), op, typeof.(args)...) == IsMutable() if is_mutable && size(x) == _size_after_multiply(size.(args)...) return IsMutable() end