Skip to content

Commit

Permalink
Implement performance optimization of promote_operation for *(::Any, …
Browse files Browse the repository at this point in the history
…::Zero)
  • Loading branch information
odow committed Apr 24, 2024
1 parent a6ed0f5 commit 49b874d
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ function Base.:/(z::Zero, x::Any)
end
end

# This method is used to provide an efficient implementation for the common
# case like `x^2 * sum(f for i in 1:0)`, which lowers to
# `_MA.operate!!(*, x^2, _MA.Zero())`. We don't need the method with reversed
# arguments because MA.Zero is not mutable, and MA never queries the mutablility
# of arguments if the first is not mutable.
promote_operation(::typeof(*), ::Type{<:Any}, ::Type{Zero}) = Zero

# Needed by `@rewrite(BigInt(1) .+ sum(1 for i in 1:0) * 1^2)`
# since we don't require mutable type to support Zero in
# `mutable_operate!`.
Expand Down
59 changes: 59 additions & 0 deletions test/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,62 @@ end
b = @allocated MA.operate(LinearAlgebra.dot, x, y)
@test a == b
end

@testset "test_multiply_expr_MA_Zero" begin
x = DummyBigInt(1)
f = DummyBigInt(2)
@test MA.@rewrite(
f * sum(x for i in 1:0),
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
sum(x for i in 1:0) * f,
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
-f * sum(x for i in 1:0),
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
sum(x for i in 1:0) * -f,
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
(f + f) * sum(x for i in 1:0),
move_factors_into_sums = false
) == MA.Zero()
@test MA.@rewrite(
sum(x for i in 1:0) * (f + f),
move_factors_into_sums = false
) == MA.Zero()
@test MA.isequal_canonical(
MA.@rewrite(f + sum(x for i in 1:0), move_factors_into_sums = false),
f,
)
@test MA.isequal_canonical(
MA.@rewrite(sum(x for i in 1:0) + f, move_factors_into_sums = false),
f,
)
@test MA.isequal_canonical(
MA.@rewrite(-f + sum(x for i in 1:0), move_factors_into_sums = false),
-f,
)
@test MA.isequal_canonical(
MA.@rewrite(sum(x for i in 1:0) + -f, move_factors_into_sums = false),
-f,
)
@test MA.isequal_canonical(
MA.@rewrite(
(f + f) + sum(x for i in 1:0),
move_factors_into_sums = false
),
f + f,
)
@test MA.isequal_canonical(
MA.@rewrite(
sum(x for i in 1:0) + (f + f),
move_factors_into_sums = false
),
f + f,
)
end

0 comments on commit 49b874d

Please sign in to comment.