diff --git a/src/rewrite.jl b/src/rewrite.jl index 8590ebe..4b3b338 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -59,7 +59,9 @@ broadcast!!(::typeof(add_mul), ::Zero, x, y) = x * y # Needed in `@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)` Base.:*(z::Zero, ::Any) = z +Base.:*(z::Zero, ::Number) = z Base.:*(::Any, z::Zero) = z +Base.:*(::Number, z::Zero) = z Base.:*(z::Zero, ::Zero) = z Base.:+(::Zero, x::Any) = x Base.:+(::Zero, x::Number) = x @@ -67,7 +69,9 @@ Base.:+(x::Any, ::Zero) = x Base.:+(x::Number, ::Zero) = x Base.:+(z::Zero, ::Zero) = z Base.:-(::Zero, x::Any) = -x +Base.:-(::Zero, x::Number) = -x Base.:-(x::Any, ::Zero) = x +Base.:-(x::Number, ::Zero) = x Base.:-(z::Zero, ::Zero) = z Base.:-(z::Zero) = z Base.:+(z::Zero) = z @@ -81,6 +85,14 @@ function Base.:/(z::Zero, x::Any) end end +function Base.:/(z::Zero, x::Number) + if iszero(x) + throw(DivideError()) + else + return z + end +end + # These methods are used to provide an efficient implementation for the common # case like `x^2 * sum(f for i in 1:0)`, which lowers to # `_MA.operate!!(*, x^2, _MA.Zero())`. We don't need the method with reversed