From 1e73552a9b7d157fa2bd9fbf147141970777227c Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Fri, 2 Aug 2024 11:44:13 +1200 Subject: [PATCH] Fix size check of arrays when broadcasting (#296) --- src/broadcast.jl | 27 ++++++++++++++++++++------- test/broadcast.jl | 7 +++++++ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index 4b8d3761..2b65c0b6 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -102,12 +102,27 @@ function broadcast_mutability(x, op, args::Vararg{Any,N}) where {N} return broadcast_mutability(typeof(x), op, typeof.(args)...) end -_checked_size(s, x::AbstractArray) = length(x) == s +# Some AbstractArray, like JuMP.Containers.SparseAxisArray, do not support +# Base.size. In such cases, we default to returning `false`, since we cannot +# safely decide whether a broadcast can be stored in `x` unless we know the +# sizes of the two entries. +function _try_size(x::AbstractArray) + try + return size(x) + catch + return missing + end +end +_try_size(x::Array) = size(x) + +_checked_size(x_size::Any, y::AbstractArray) = x_size == _try_size(y) _checked_size(::Any, ::Any) = true _checked_size(::Any, ::Tuple{}) = true -function _checked_size(s, x::Tuple) - return _checked_size(s, x[1]) && _checked_size(s, Base.tail(x)) +function _checked_size(x_size::Any, y::Tuple) + return _checked_size(x_size, y[1]) && _checked_size(x_size, Base.tail(y)) end +_checked_size(::Missing, ::Tuple) = false +_checked_size(::Missing, ::Tuple{}) = false # This method is a slightly tricky one: # @@ -115,15 +130,13 @@ end # happen during broadcasting since we'll either need to return a different size # to `x`, or multiple copies of an argument will be used for different parts of # `x`. To simplify, let's just return `IsNotMutable` if the sizes are different, -# which will be slower but correct. This is slightly complicated by the fact -# that some AbstractArray do not support `size`, so we check with `length` -# instead. If the `size`s are different, a later error will be thrown. +# which will be slower but correct. function broadcast_mutability( x::AbstractArray, op, args::Vararg{Any,N}, ) where {N} - if !_checked_size(length(x), args)::Bool + if !_checked_size(_try_size(x), args)::Bool return IsNotMutable() end return broadcast_mutability(typeof(x), op, typeof.(args)...) diff --git a/test/broadcast.jl b/test/broadcast.jl index edad933a..4124230b 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -52,3 +52,10 @@ Base.BroadcastStyle(::Type{Struct221}) = BroadcastStyle221() @testset "promote_broadcast_for_new_style" begin @test MA.promote_broadcast(MA.add_mul, Vector{Int}, Struct221) === Any end + +@testset "broadcast_length_1_dimensions" begin + A = rand(2, 1, 3) + B = rand(2, 3) + @test MA.broadcast!!(MA.sub_mul, A, B) ≈ A .- B + @test MA.broadcast!!(MA.sub_mul, B, A) ≈ B .- A +end