Rewrite of sum()*constant is suboptimal #169

Closed
odow opened this issue Oct 13, 2022 · 5 comments · Fixed by #170

@odow
Member

odow commented Oct 13, 2022

Consider:

julia> import MutableArithmetics

julia> @macroexpand MutableArithmetics.@rewrite(sum(i for i in 1:2) * 2)
quote
    #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:293 =#
    let
        #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:294 =#
        begin
            var"#892###928" = (MutableArithmetics.MutableArithmetics).Zero()
            begin
                for i = 1:2
                    var"#892###928" = (MutableArithmetics.MutableArithmetics).operate!!((MutableArithmetics.MutableArithmetics).add_mul, var"#892###928", i, 2)
                end
                var"#893###927" = var"#892###928"
            end
        end
        #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:295 =#
        var"#893###927"
    end
end

julia> @macroexpand MutableArithmetics.@rewrite(sum(i for i in 1:2) / 2)
quote
    #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:293 =#
    let
        #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:294 =#
        begin
            var"#896###930" = (MutableArithmetics.MutableArithmetics).Zero()
            begin
                for i = 1:2
                    var"#896###930" = (MutableArithmetics.MutableArithmetics).operate!!((MutableArithmetics.MutableArithmetics).add_mul, var"#896###930", i, 1 / 2)
                end
                var"#897###929" = var"#896###930"
            end
        end
        #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:295 =#
        var"#897###929"
    end
end

julia> @macroexpand MutableArithmetics.@rewrite(2 * sum(i for i in 1:2))
quote
    #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:293 =#
    let
        #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:294 =#
        begin
            var"#908###936" = (MutableArithmetics.MutableArithmetics).Zero()
            begin
                for i = 1:2
                    var"#908###936" = (MutableArithmetics.MutableArithmetics).operate!!((MutableArithmetics.MutableArithmetics).add_mul, var"#908###936", 2, i)
                end
                var"#909###935" = var"#908###936"
            end
        end
        #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:295 =#
        var"#909###935"
    end
end

julia> @macroexpand MutableArithmetics.@rewrite(2 / sum(i for i in 1:2))
quote
    #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:293 =#
    let
        #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:294 =#
        begin
            var"#912###938" = (MutableArithmetics.MutableArithmetics).Zero()
            var"#913###937" = (MutableArithmetics.MutableArithmetics).operate!!((MutableArithmetics.MutableArithmetics).add_mul, var"#912###938", 2, 1 / sum((i for i = 1:2)))
        end
        #= /Users/oscar/.julia/packages/MutableArithmetics/maUDe/src/rewrite.jl:295 =#
        var"#913###937"
    end
end

All cases are troubling:

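  • The 1st and 3rd examples lead to n additional * operations, because the constant is multiplied into every term.
  • The 2nd example leads to n additional / operations, because 1 / 2 is re-evaluated and passed into every add_mul.
  • The 4th doesn't even use mutation for the sum: it falls back to Base's sum((i for i = 1:2)), and only the final add_mul goes through operate!!.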
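To make the cost in the 1st and 3rd bullets concrete, here is a small self-contained sketch in plain Julia (no MutableArithmetics involved). Tracked, scale_inside, scale_outside, and count_muls are hypothetical names made up for this illustration: scale_inside mimics the rewritten form, where the constant is applied inside the loop, and scale_outside sums first and scales once.

```julia
# Global counter incremented by every `*` on a Tracked value.
const MUL_COUNT = Ref(0)

# Hypothetical Float64 wrapper that counts multiplications by a constant.
struct Tracked
    value::Float64
end
Base.:+(a::Tracked, b::Tracked) = Tracked(a.value + b.value)
Base.zero(::Type{Tracked}) = Tracked(0.0)
function Base.:*(a::Tracked, c::Real)
    MUL_COUNT[] += 1
    return Tracked(a.value * c)
end

# Mimics the rewritten form: the constant is applied to every term in the loop.
function scale_inside(xs::Vector{Tracked}, c::Real)
    acc = zero(Tracked)
    for x in xs
        acc = acc + x * c
    end
    return acc
end

# Mimics what a user would write by hand: sum first, then scale once.
scale_outside(xs::Vector{Tracked}, c::Real) = sum(xs) * c

# Run `f` and report its result together with the number of `*` calls it made.
function count_muls(f, xs, c)
    MUL_COUNT[] = 0
    result = f(xs, c)
    return result.value, MUL_COUNT[]
end

xs = Tracked.(1:10)
println(count_muls(scale_inside, xs, 2))   # (110.0, 10)
println(count_muls(scale_outside, xs, 2))  # (110.0, 1)
```

Both strategies return the same value; only the number of multiplications differs (n versus 1).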

This came up while digging into this example from jump-dev/JuMP.jl#3106:

@objective(
    model,
    Max,
    n / 2 * log(1 / (2 * π * σ^2)) -
    sum((data[i] - μ)^2 for i in 1:n) / (2 * σ^2)
)

The first step brings the denominator into the loop:

@objective(
    model,
    Max,
    n / 2 * log(1 / (2 * π * σ^2)) -
    sum((data[i] - μ)^2 * (1 / (2 * σ^2)) for i in 1:n)
)

but then, even worse, it also brings the subtraction into the loop:

term = n / 2 * log(1 / (2 * π * σ^2))
for i in 1:n
    term = operate!!(sub_mul, term, (data[i] - μ)^2, 1 / (2 * σ^2))
end
@objective(model, Max, term)

A better outcome would be something along the lines of:

term = n / 2 * log(1 / (2 * π * σ^2))
sum_term = 0
for i in 1:n
    sum_term = operate!!(add_mul, sum_term, (data[i] - μ)^2)
end
obj = operate(-, term, operate(/, sum_term, 2 * σ^2))
@objective(model, Max, obj)

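To see why this matters for nonlinear terms, the sketch below builds both shapes of the objective as plain Julia Expr trees, used here only as a stand-in for JuMP's nonlinear expression graph, and counts the :call nodes. The helpers count_calls, sq, rewritten, and preferred are hypothetical and exist only for this illustration.

```julia
# Count the :call nodes in a Julia expression tree.
count_calls(x) = 0
count_calls(ex::Expr) = (ex.head == :call ? 1 : 0) + sum(count_calls, ex.args; init = 0)

# The i-th squared residual (data[i] - μ)^2, kept symbolic.
sq(i) = :((data[$i] - μ)^2)

# Shape produced by the current rewrite: scale and subtract inside the loop.
# (`n`, `σ`, `μ` and `data` inside the quotes are just symbols.)
function rewritten(n::Int)
    ex = :(n / 2 * log(1 / (2 * π * σ^2)))
    for i in 1:n
        ex = :($ex - $(sq(i)) * (1 / (2 * σ^2)))
    end
    return ex
end

# Preferred shape: build the sum, then divide and subtract once.
function preferred(n::Int)
    s = sq(1)
    for i in 2:n
        s = :($s + $(sq(i)))
    end
    return :(n / 2 * log(1 / (2 * π * σ^2)) - $s / (2 * σ^2))
end

println(count_calls(rewritten(3)))  # 27
println(count_calls(preferred(3)))  # 18
```

For n = 3 the rewritten tree has 27 call nodes and the preferred one 18; each additional data point adds 7 nodes to the former but only 3 to the latter.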
@odow
Member Author

odow commented Oct 13, 2022

@blegat, what was the motivation for the current design? Why not rewrite only the inner loops?

@blegat
Member

blegat commented Oct 13, 2022

JuMP has an efficient add_mul operation; for example, operate!(add_mul, ::AffExpr, ::Float64, ::VariableRef) is efficient. So if you write sum(x[i] for i in eachindex(x)) * 2, where x is a vector of variables, you'd rather add each term directly with its coefficient than build the sum and then scale it.
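A minimal sketch of this argument, using a hypothetical dictionary-backed Aff type in place of JuMP's AffExpr (Aff, add_mul!, and scale! are illustrative names, not JuMP or MutableArithmetics API): folding the constant into each add_mul! touches each term once, whereas summing first and then scaling touches every term twice.

```julia
# Hypothetical affine expression: constant + Σ coef * var (not JuMP's AffExpr).
mutable struct Aff
    constant::Float64
    terms::Dict{Symbol,Float64}
end
Aff() = Aff(0.0, Dict{Symbol,Float64}())

# In-place `expr += coef * var`: a single dictionary update per call.
function add_mul!(expr::Aff, coef::Real, var::Symbol)
    expr.terms[var] = get(expr.terms, var, 0.0) + coef
    return expr
end

# In-place `expr *= c`: must revisit every stored term.
function scale!(expr::Aff, c::Real)
    expr.constant *= c
    for var in collect(keys(expr.terms))
        expr.terms[var] *= c
    end
    return expr
end

x = [:x1, :x2, :x3]

# Rewrite strategy for `sum(x[i] for i in eachindex(x)) * 2`:
# the coefficient 2 goes straight into each add_mul!.
folded = Aff()
for xi in x
    add_mul!(folded, 2, xi)
end

# Sum-then-scale strategy: add with coefficient 1, then rescale every term.
scaled = Aff()
for xi in x
    add_mul!(scaled, 1, xi)
end
scale!(scaled, 2)

println(folded.terms == scaled.terms)  # true
```

Both strategies build the same expression; the folded version simply avoids the second pass over the terms.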

@odow
Member Author

odow commented Oct 13, 2022

I guess the design makes sense for linear/quadratic terms, but it really screws up the expression graph of nonlinear terms.

@odow
Member Author

odow commented Oct 13, 2022

JuMP also has an efficient operate!(*, ::AffExpr, ::Float64) method, though.

@odow
Member Author

odow commented Oct 27, 2022

@blegat asks for a benchmark of sum(x[i] for i in 1:n) - sum(x[i]^2 for i in 1:n).
