Skip to content

Commit

Permalink
Fix mutability check for array multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Mar 24, 2024
1 parent 5e901e1 commit 6928f20
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 0 deletions.
33 changes: 33 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,39 @@ function mutability(x, op, args::Vararg{Any,N}) where {N}
return mutability(typeof(x), op, typeof.(args)...)
end

# As a special case, we need to check that the shapes of matrix multiplication
# match in order for the array to be mutable.
function mutability(
x::AbstractArray,
op::typeof(*),
args::Vararg{Any,N},
) where {N}
if mutability(typeof(x), op, typeof.(args)...) == IsNotMutable()
return IsNotMutable()
elseif size(x) == _size_after_multiply(size.(args)...)
return IsMutable()
else
return IsNotMutable()
end
end

function _size_after_multiply(x::NTuple{M,Int}, y::NTuple{N,Int}) where {N,M}
if x[end] != y[1]
return nothing
end
return (x[1:end-1]..., y[2:end]...)
end

_size_after_multiply(::Tuple{}, rhs::NTuple{M,Int}) where {M} = rhs
_size_after_multiply(lhs::NTuple{M,Int}, ::Tuple{}) where {M} = lhs
_size_after_multiply(::Tuple{}, ::Tuple{}) = ()
_size_after_multiply(::Nothing, ::Any) = nothing

function _size_after_multiply(x::NTuple{M,Int}, y::Vararg{Any,N}) where {N,M}
head = _size_after_multiply(x, Base.first(y))
return _size_after_multiply(head, Base.tail(y)...)
end

mutability(::Type) = IsNotMutable()

"""
Expand Down
19 changes: 19 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,22 @@ end
@test MA.operate_to!!(T(6), abs, T(7)) == 7
@test MA.operate_to!!(T(6), abs, T(-7)) == 7
end

@testset "issue_271_mutability" begin
a = 1
x = [1; 2;;]
y = [1 2; 3 4]
z = [1 2; 3 4; 5 6]
@test MA.mutability(x, *, x, x') == MA.IsNotMutable()
@test MA.mutability(x, *, x', x) == MA.IsNotMutable()
@test MA.mutability(x, *, x, a, x') == MA.IsNotMutable()
@test MA.mutability(x, *, x', a, x) == MA.IsNotMutable()
@test MA.mutability(y, *, y, y) == MA.IsMutable()
@test MA.mutability(y, *, y, y, y) == MA.IsMutable()
@test MA.mutability(y, *, y, a, y') == MA.IsMutable()
@test MA.mutability(y, *, y', a, y) == MA.IsMutable()
@test MA.mutability(y, *, a, a, y) == MA.IsMutable()
@test MA.mutability(y, *, y, z', z) == MA.IsMutable()
@test MA.mutability(z, *, z, z) == MA.IsNotMutable()
@test MA.mutability(z, *, z, z, y) == MA.IsNotMutable()
end
13 changes: 13 additions & 0 deletions test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,16 @@ Base.:*(m::Monomial, ::Monomial) = m
@test T == typeof(MA.operate(*, a, b))
end
end

@testset "Issue_271" begin
A = [1; 2;;]
B = [1 2]
C = MA.operate!!(*, A, B)
@test A == [1; 2;;]
@test B == [1 2]
@test C == AB
D = MA.operate!!(*, B, A)
@test A == [1; 2;;]
@test B == [1 2]
@test D == BA
end

0 comments on commit 6928f20

Please sign in to comment.