From 9ab66f20028ddd58a7f32b8fad06acd735b95785 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Tue, 7 Jun 2022 08:53:39 +1200 Subject: [PATCH] Fix broadcast for mis-matched arrays (#159) --- src/broadcast.jl | 25 +++++++++++++++++++++++++ test/broadcast.jl | 11 +++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/broadcast.jl b/src/broadcast.jl index 107f1d45..1cfd307f 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -89,6 +89,31 @@ function broadcast_mutability(x, op, args::Vararg{Any,N}) where {N} return broadcast_mutability(typeof(x), op, typeof.(args)...) end +_checked_size(s, x::AbstractArray) = size(x) == s +_checked_size(::Any, ::Any) = true +_checked_size(::Any, ::Tuple{}) = true +function _checked_size(s, x::Tuple) + return _checked_size(s, x[1]) && _checked_size(s, Base.tail(x)) +end + +# This method is a slightly tricky one: +# +# If the elements in the broadcast are different sized arrays, weird things can +# happen during broadcasting since we'll either need to return a different size +# to `x`, or multiple copies of an argument will be used for different parts of +# `x`. To simplify, let's just return `IsNotMutable` if the sizes are different, +# which will be slower but correct. +function broadcast_mutability( + x::AbstractArray, + op, + args::Vararg{Any,N}, +) where {N} + if !_checked_size(size(x), args)::Bool + return IsNotMutable() + end + return broadcast_mutability(typeof(x), op, typeof.(args)...) +end + broadcast_mutability(::Type) = IsNotMutable() """ diff --git a/test/broadcast.jl b/test/broadcast.jl index 372cfb93..e660967d 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -19,6 +19,7 @@ const MA = MutableArithmetics alloc_test(() -> MA.broadcast!!(+, a, b), 0) alloc_test(() -> MA.broadcast!!(+, a, c), 0) end + @testset "BigInt" begin x = BigInt(1) y = BigInt(2) @@ -34,3 +35,13 @@ end alloc_test(() -> MA.broadcast!!(+, a, b), 30 * sizeof(Int)) alloc_test(() -> MA.broadcast!!(+, a, c), 0) end + +@testset "broadcast_issue_158" begin + x, y = BigInt[2 3], BigInt[2 3; 3 4] + @test MA.@rewrite(x .+ y) == x .+ y + @test MA.@rewrite(x .- y) == x .- y + @test MA.@rewrite(y .+ x) == y .+ x + @test MA.@rewrite(y .- x) == y .- x + @test MA.@rewrite(y .* x) == y .* x + @test MA.@rewrite(x .* y) == x .* y +end