diff --git a/src/rewrite.jl b/src/rewrite.jl index 44d43a7..811570f 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -79,6 +79,21 @@ function Base.:/(z::Zero, x::Any) end end +# These methods are used to provide an efficient implementation for the common +# case like `x^2 * sum(f for i in 1:0)`, which lowers to +# `_MA.operate!!(*, x^2, _MA.Zero())`. We don't need the method with reversed +# arguments because MA.Zero is not mutable, and MA never queries the mutablility +# of arguments if the first is not mutable. +promote_operation(::typeof(*), ::Type{<:Any}, ::Type{Zero}) = Zero + +function promote_operation( + ::typeof(*), + ::Type{<:AbstractArray{T}}, + ::Type{Zero}, +) where {T} + return Zero +end + # Needed by `@rewrite(BigInt(1) .+ sum(1 for i in 1:0) * 1^2)` # since we don't require mutable type to support Zero in # `mutable_operate!`. diff --git a/test/rewrite.jl b/test/rewrite.jl index e2d1c8a..840174b 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -199,3 +199,70 @@ end b = @allocated MA.operate(LinearAlgebra.dot, x, y) @test a == b end + +@testset "test_multiply_expr_MA_Zero" begin + x = DummyBigInt(1) + f = DummyBigInt(2) + @test MA.@rewrite( + f * sum(x for i in 1:0), + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + sum(x for i in 1:0) * f, + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + -f * sum(x for i in 1:0), + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + sum(x for i in 1:0) * -f, + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + (f + f) * sum(x for i in 1:0), + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + sum(x for i in 1:0) * (f + f), + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + -[f] * sum(x for i in 1:0), + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + sum(x for i in 1:0) * -[f], + move_factors_into_sums = false + ) == MA.Zero() + @test MA.isequal_canonical( + MA.@rewrite(f + sum(x for i in 1:0), move_factors_into_sums = false), + f, + ) + @test MA.isequal_canonical( + MA.@rewrite(sum(x for i in 1:0) + f, move_factors_into_sums = false), + f, + ) + @test MA.isequal_canonical( + MA.@rewrite(-f + sum(x for i in 1:0), move_factors_into_sums = false), + -f, + ) + @test MA.isequal_canonical( + MA.@rewrite(sum(x for i in 1:0) + -f, move_factors_into_sums = false), + -f, + ) + @test MA.isequal_canonical( + MA.@rewrite( + (f + f) + sum(x for i in 1:0), + move_factors_into_sums = false + ), + f + f, + ) + @test MA.isequal_canonical( + MA.@rewrite( + sum(x for i in 1:0) + (f + f), + move_factors_into_sums = false + ), + f + f, + ) +end