Skip to content

Commit

Permalink
Fix broadcast for mis-matched arrays (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Jun 6, 2022
1 parent 788364d commit 9ab66f2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,31 @@ function broadcast_mutability(x, op, args::Vararg{Any,N}) where {N}
return broadcast_mutability(typeof(x), op, typeof.(args)...)
end

_checked_size(s, x::AbstractArray) = size(x) == s
_checked_size(::Any, ::Any) = true
_checked_size(::Any, ::Tuple{}) = true
function _checked_size(s, x::Tuple)
return _checked_size(s, x[1]) && _checked_size(s, Base.tail(x))
end

# This method is a slightly tricky one:
#
# If the elements in the broadcast are different sized arrays, weird things can
# happen during broadcasting since we'll either need to return a different size
# to `x`, or multiple copies of an argument will be used for different parts of
# `x`. To simplify, let's just return `IsNotMutable` if the sizes are different,
# which will be slower but correct.
function broadcast_mutability(
x::AbstractArray,
op,
args::Vararg{Any,N},
) where {N}
if !_checked_size(size(x), args)::Bool
return IsNotMutable()
end
return broadcast_mutability(typeof(x), op, typeof.(args)...)
end

broadcast_mutability(::Type) = IsNotMutable()

"""
Expand Down
11 changes: 11 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const MA = MutableArithmetics
alloc_test(() -> MA.broadcast!!(+, a, b), 0)
alloc_test(() -> MA.broadcast!!(+, a, c), 0)
end

@testset "BigInt" begin
x = BigInt(1)
y = BigInt(2)
Expand All @@ -34,3 +35,13 @@ end
alloc_test(() -> MA.broadcast!!(+, a, b), 30 * sizeof(Int))
alloc_test(() -> MA.broadcast!!(+, a, c), 0)
end

@testset "broadcast_issue_158" begin
x, y = BigInt[2 3], BigInt[2 3; 3 4]
@test MA.@rewrite(x .+ y) == x .+ y
@test MA.@rewrite(x .- y) == x .- y
@test MA.@rewrite(y .+ x) == y .+ x
@test MA.@rewrite(y .- x) == y .- x
@test MA.@rewrite(y .* x) == y .* x
@test MA.@rewrite(x .* y) == x .* y
end

0 comments on commit 9ab66f2

Please sign in to comment.