Skip to content

Commit

Permalink
Fix size check of arrays when broadcasting (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Aug 1, 2024
1 parent 0e42a16 commit 1e73552
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
27 changes: 20 additions & 7 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,28 +102,41 @@ function broadcast_mutability(x, op, args::Vararg{Any,N}) where {N}
return broadcast_mutability(typeof(x), op, typeof.(args)...)
end

_checked_size(s, x::AbstractArray) = length(x) == s
# Some AbstractArray, like JuMP.Containers.SparseAxisArray, do not support
# Base.size. In such cases, we default to returning `false`, since we cannot
# safely decide whether a broadcast can be stored in `x` unless we know the
# sizes of the two entries.
function _try_size(x::AbstractArray)
try
return size(x)
catch
return missing
end
end
_try_size(x::Array) = size(x)

_checked_size(x_size::Any, y::AbstractArray) = x_size == _try_size(y)
_checked_size(::Any, ::Any) = true
_checked_size(::Any, ::Tuple{}) = true
function _checked_size(s, x::Tuple)
return _checked_size(s, x[1]) && _checked_size(s, Base.tail(x))
function _checked_size(x_size::Any, y::Tuple)
return _checked_size(x_size, y[1]) && _checked_size(x_size, Base.tail(y))
end
_checked_size(::Missing, ::Tuple) = false
_checked_size(::Missing, ::Tuple{}) = false

# This method is a slightly tricky one:
#
# If the elements in the broadcast are different sized arrays, weird things can
# happen during broadcasting since we'll either need to return a different size
# to `x`, or multiple copies of an argument will be used for different parts of
# `x`. To simplify, let's just return `IsNotMutable` if the sizes are different,
# which will be slower but correct. This is slightly complicated by the fact
# that some AbstractArray do not support `size`, so we check with `length`
# instead. If the `size`s are different, a later error will be thrown.
# which will be slower but correct.
function broadcast_mutability(
x::AbstractArray,
op,
args::Vararg{Any,N},
) where {N}
if !_checked_size(length(x), args)::Bool
if !_checked_size(_try_size(x), args)::Bool
return IsNotMutable()
end
return broadcast_mutability(typeof(x), op, typeof.(args)...)
Expand Down
7 changes: 7 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,10 @@ Base.BroadcastStyle(::Type{Struct221}) = BroadcastStyle221()
@testset "promote_broadcast_for_new_style" begin
@test MA.promote_broadcast(MA.add_mul, Vector{Int}, Struct221) === Any
end

@testset "broadcast_length_1_dimensions" begin
A = rand(2, 1, 3)
B = rand(2, 3)
@test MA.broadcast!!(MA.sub_mul, A, B) A .- B
@test MA.broadcast!!(MA.sub_mul, B, A) B .- A
end

0 comments on commit 1e73552

Please sign in to comment.