From 49b874de661f819f2393465e0f6c916d84b3c4f1 Mon Sep 17 00:00:00 2001 From: odow Date: Thu, 25 Apr 2024 09:48:44 +1200 Subject: [PATCH 1/3] Implement performance optimization of promote_operation for *(::Any, ::Zero) --- src/rewrite.jl | 7 ++++++ test/rewrite.jl | 59 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/src/rewrite.jl b/src/rewrite.jl index 44d43a7..0f32127 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -79,6 +79,13 @@ function Base.:/(z::Zero, x::Any) end end +# This method is used to provide an efficient implementation for the common +# case like `x^2 * sum(f for i in 1:0)`, which lowers to +# `_MA.operate!!(*, x^2, _MA.Zero())`. We don't need the method with reversed +# arguments because MA.Zero is not mutable, and MA never queries the mutablility +# of arguments if the first is not mutable. +promote_operation(::typeof(*), ::Type{<:Any}, ::Type{Zero}) = Zero + # Needed by `@rewrite(BigInt(1) .+ sum(1 for i in 1:0) * 1^2)` # since we don't require mutable type to support Zero in # `mutable_operate!`. diff --git a/test/rewrite.jl b/test/rewrite.jl index e2d1c8a..6f38548 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -199,3 +199,62 @@ end b = @allocated MA.operate(LinearAlgebra.dot, x, y) @test a == b end + +@testset "test_multiply_expr_MA_Zero" begin + x = DummyBigInt(1) + f = DummyBigInt(2) + @test MA.@rewrite( + f * sum(x for i in 1:0), + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + sum(x for i in 1:0) * f, + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + -f * sum(x for i in 1:0), + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + sum(x for i in 1:0) * -f, + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + (f + f) * sum(x for i in 1:0), + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + sum(x for i in 1:0) * (f + f), + move_factors_into_sums = false + ) == MA.Zero() + @test MA.isequal_canonical( + MA.@rewrite(f + sum(x for i in 1:0), move_factors_into_sums = false), + f, + ) + @test MA.isequal_canonical( + MA.@rewrite(sum(x for i in 1:0) + f, move_factors_into_sums = false), + f, + ) + @test MA.isequal_canonical( + MA.@rewrite(-f + sum(x for i in 1:0), move_factors_into_sums = false), + -f, + ) + @test MA.isequal_canonical( + MA.@rewrite(sum(x for i in 1:0) + -f, move_factors_into_sums = false), + -f, + ) + @test MA.isequal_canonical( + MA.@rewrite( + (f + f) + sum(x for i in 1:0), + move_factors_into_sums = false + ), + f + f, + ) + @test MA.isequal_canonical( + MA.@rewrite( + sum(x for i in 1:0) + (f + f), + move_factors_into_sums = false + ), + f + f, + ) +end From d58c1678e2c63855abd3e2cdd4cfda23e4973423 Mon Sep 17 00:00:00 2001 From: odow Date: Thu, 25 Apr 2024 10:39:16 +1200 Subject: [PATCH 2/3] Update --- src/rewrite.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/rewrite.jl b/src/rewrite.jl index 0f32127..6d6dc0b 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -85,6 +85,7 @@ end # arguments because MA.Zero is not mutable, and MA never queries the mutablility # of arguments if the first is not mutable. promote_operation(::typeof(*), ::Type{<:Any}, ::Type{Zero}) = Zero +promote_operation(::typeof(*), ::Type{<:AbstractArray}, ::Type{Zero}) = Zero # Needed by `@rewrite(BigInt(1) .+ sum(1 for i in 1:0) * 1^2)` # since we don't require mutable type to support Zero in From c7571c7c9aa7eaecc729b6ebe9b3f52224324217 Mon Sep 17 00:00:00 2001 From: odow Date: Thu, 25 Apr 2024 10:58:03 +1200 Subject: [PATCH 3/3] Update --- src/rewrite.jl | 11 +++++++++-- test/rewrite.jl | 8 ++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/rewrite.jl b/src/rewrite.jl index 6d6dc0b..811570f 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -79,13 +79,20 @@ function Base.:/(z::Zero, x::Any) end end -# This method is used to provide an efficient implementation for the common +# These methods are used to provide an efficient implementation for the common # case like `x^2 * sum(f for i in 1:0)`, which lowers to # `_MA.operate!!(*, x^2, _MA.Zero())`. We don't need the method with reversed # arguments because MA.Zero is not mutable, and MA never queries the mutablility # of arguments if the first is not mutable. promote_operation(::typeof(*), ::Type{<:Any}, ::Type{Zero}) = Zero -promote_operation(::typeof(*), ::Type{<:AbstractArray}, ::Type{Zero}) = Zero + +function promote_operation( + ::typeof(*), + ::Type{<:AbstractArray{T}}, + ::Type{Zero}, +) where {T} + return Zero +end # Needed by `@rewrite(BigInt(1) .+ sum(1 for i in 1:0) * 1^2)` # since we don't require mutable type to support Zero in diff --git a/test/rewrite.jl b/test/rewrite.jl index 6f38548..840174b 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -227,6 +227,14 @@ end sum(x for i in 1:0) * (f + f), move_factors_into_sums = false ) == MA.Zero() + @test MA.@rewrite( + -[f] * sum(x for i in 1:0), + move_factors_into_sums = false + ) == MA.Zero() + @test MA.@rewrite( + sum(x for i in 1:0) * -[f], + move_factors_into_sums = false + ) == MA.Zero() @test MA.isequal_canonical( MA.@rewrite(f + sum(x for i in 1:0), move_factors_into_sums = false), f,