
Add new rewrite that does not make strong assumptions on result type #170

Merged

odow merged 7 commits into master from od/rewrite-2 on Nov 21, 2022

Conversation

@odow (Member) commented Oct 28, 2022

Closes #44
Closes #169

This branch is interesting to try out.

First, the current @rewrite produces different results from what we would obtain using the non-mutating API. I'm surprised this hasn't come up before.

Second, this PR has better performance than the current @rewrite, at least for the BigXXX types. The differences can be quite large (2x). Notably, just evaluating things is faster, likely because of type stability?

Benchmarks

Code
module BenchmarkMA

import MutableArithmetics

const MutableArithmetics2 = MutableArithmetics.MutableArithmetics2

module ImmutableArithmetics
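    # Baseline "IM" implementation: a no-op @rewrite that simply evaluates the
    # expression with ordinary (immutable) arithmetic.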
    macro rewrite(expr)
        return esc(expr)
    end
end

import BenchmarkTools
import DataFrames
import JuMP
import LinearAlgebra
import PrettyTables
import Printf
import SparseArrays
import Test

const TYPES = (Int32, Float64, BigInt, BigFloat, JuMP.VariableRef)

function run_benchmarks(types = TYPES)
    lookup =
        Dict("bench_MA_" => "MA", "bench_IM_" => "IM", "bench_MA2_" => "MA2")
    suite = BenchmarkTools.BenchmarkGroup()
    for T in types
        suite["$T"] = BenchmarkTools.BenchmarkGroup()
        for v in values(lookup)
            suite["$T"][v] = BenchmarkTools.BenchmarkGroup()
        end
    end
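    # Discover benchmark functions by reflection: any function in this module
    # named `bench_MA_*`, `bench_IM_*`, or `bench_MA2_*` is registered in the
    # suite for each element type.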
    for name in names(@__MODULE__; all = true)
        f = getfield(@__MODULE__, name)
        for T in types
            for (k, v) in lookup
                if startswith("$name", k)
                    fname = replace("$name", k => "")
                    suite["$T"][v][fname] = BenchmarkTools.@benchmarkable $f($T)
                    break
                end
            end
        end
    end
    return BenchmarkTools.run(suite)
end

function run_benchmark(T, name::String)
    lookup =
        Dict("bench_MA_" => "MA", "bench_IM_" => "IM", "bench_MA2_" => "MA2")
    suite = BenchmarkTools.BenchmarkGroup()
    for (k, v) in lookup
        f = getfield(@__MODULE__, Symbol("$k$name"))
        suite[v] = BenchmarkTools.@benchmarkable $f($T)
    end
    return BenchmarkTools.run(suite)
end

function print_results(f, results)
    function compare(t, m, name, result)
        if m == "MA"
            return Printf.@sprintf("%1.2e", f(result))
        end
        value = f(result)
        diff = value / f(results["$t"]["MA"][name])
        str = Printf.@sprintf("%1.2e (%.2f)", value, diff)
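        # Flag deviations from the MA baseline: "+"/"++" if more than 5%/10%
        # above it, "-"/"--" if more than 5%/10% below it.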
        stars = string(
            diff > 1.05 ? "+" : "",
            diff > 1.1 ? "+" : "",
            diff < 0.9 ? "-" : "",
            diff < 0.95 ? "-" : "",
        )
        return string(str, ' ', rpad(stars, 2))
    end
    df = DataFrames.DataFrame([
        (
            type=t,
            method="$m",
            name=name,
            result = compare(t, m, name, result),
        ) for (t, v) in results
        for (m, vv) in v
        for (name, result) in vv
    ])
    df = DataFrames.unstack(df, [:type, :name], :method, :result)
    DataFrames.sort!(df, [:type, :name])
    DataFrames.sort!(df, [:name, :type])
    return PrettyTables.pretty_table(df; crop = :none)
end

function run_tests(; atol = 0.0)
    tests = Set{String}()
    for name in names(@__MODULE__; all = true)
        if startswith("$name", "bench_IM_")
            push!(tests, replace("$name", "bench_IM_" => ""))
        end
    end
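    # Check that MA and MA2 return the same value as the immutable ("IM")
    # reference implementation for every benchmark and element type.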
    Test.@testset "Tests" begin
        Test.@testset "$T" for T in TYPES
            Test.@testset "$M" for M in ("MA", "MA2")
                Test.@testset "$test" for test in tests
                    f = getfield(@__MODULE__, Symbol("bench_IM_$test"))
                    g = getfield(@__MODULE__, Symbol("bench_$(M)_$test"))
                    Test.@test _is_approx(f(T), g(T); atol = atol)
                end
            end
        end
    end
    return
end

_is_approx(::Nothing, ::Nothing; kwargs...) = true

_is_approx(x, y; kwargs...) = isapprox(x, y; kwargs...)

function _is_approx(x::AbstractArray, y::AbstractArray; kwargs...)
    return all(_is_approx.(x, y; kwargs...))
end

function _is_approx(x::T, y::T; kwargs...) where {T<:JuMP.AbstractJuMPScalar}
    return string(x) == string(y)
end

new_element(::Type{T}, n) where {T} = [T(1) for _ in 1:n]

function new_element(::Type{JuMP.VariableRef}, n)
    model = JuMP.Model()
    return JuMP.@variable(model, x[1:n])
end

supports_nonlinear(::Type{T}) where {T} = true

supports_nonlinear(::Type{JuMP.VariableRef}) = false
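# Benchmarks that cannot be run with JuMP variables (for example, pow uses a
# variable exponent) check this and return early for JuMP.VariableRef.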

for (f, M) in Dict(
    :MA => MutableArithmetics,
    :MA2 => MutableArithmetics2,
    :IM => ImmutableArithmetics,
)
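    # Define each benchmark once per implementation by splicing that module's
    # @rewrite macro into an otherwise identical function body.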
    @eval begin
        function $(Symbol("bench_$(f)_plus"))(T)
            x = new_element(T, 2)
            $M.@rewrite(x[1] + x[2])
        end
        function $(Symbol("bench_$(f)_pow"))(T)
            if !supports_nonlinear(T)
                return
            end
            x = new_element(T, 2)
            $M.@rewrite(x[1]^x[2])
        end
        function $(Symbol("bench_$(f)_summation_generator"))(T)
            x = new_element(T, 2)
            $M.@rewrite(sum(j + x[i] for i in 1:2, j in 1:2) / 0.1)
        end
        function $(Symbol("bench_$(f)_summation_flatten"))(T)
            x = new_element(T, 2)
            $M.@rewrite(sum(j + x[i] for i in 1:2 for j in 1:i) / 0.1)
        end
        function $(Symbol("bench_$(f)_summation_affine"))(T)
            n = 10_000
            x = new_element(T, n)
            $M.@rewrite(sum(i * x[i] for i in 1:n))
        end
        function $(Symbol("bench_$(f)_summation_quadratic"))(T)
            n = 10_000
            x = new_element(T, n)
            $M.@rewrite(sum(i * x[i]^2 for i in 1:n))
        end
        function $(Symbol("bench_$(f)_summation_affine_post_mult"))(T)
            n = 10_000
            x = new_element(T, n)
            $M.@rewrite(sum(i * x[i] for i in 1:n) * 2)
        end
        function $(Symbol("bench_$(f)_summation_quadratic_post_mult"))(T)
            n = 10_000
            x = new_element(T, n)
            $M.@rewrite(sum(i * x[i]^2 for i in 1:n) * 2)
        end
        function $(Symbol("bench_$(f)_summation_vector"))(T)
            x = new_element(T, 10_000)
            $M.@rewrite(sum(x))
        end
        function $(Symbol("bench_$(f)_summation_affine_minus_quad"))(T)
            n = 100
            x = new_element(T, n)
            $M.@rewrite(sum(x[i] for i in 1:n) - sum(x[i]^2 for i in 1:n))
        end
        function $(Symbol("bench_$(f)_summation_affine_plus_quad"))(T)
            n = 100
            x = new_element(T, n)
            $M.@rewrite(sum(x[i] for i in 1:n) + sum(x[i]^2 for i in 1:n))
        end
        function $(Symbol("bench_$(f)_summation_quad_minus_affine"))(T)
            n = 100
            x = new_element(T, n)
            $M.@rewrite(sum(x[i]^2 for i in 1:n) - sum(x[i] for i in 1:n))
        end
        function $(Symbol("bench_$(f)_summation_quad_plus_affine"))(T)
            n = 100
            x = new_element(T, n)
            $M.@rewrite(sum(x[i]^2 for i in 1:n) + sum(x[i] for i in 1:n))
        end
        function $(Symbol("bench_$(f)_broadcast_vector_square"))(T)
            x = new_element(T, 10_000)
            $M.@rewrite(sum(x .^ 2))
        end
        function $(Symbol("bench_$(f)_vector_product"))(T)
            x = new_element(T, 10_000)
            $M.@rewrite(x' * x)
        end
        function $(Symbol("bench_$(f)_matrix_symmetric"))(T)
            n = 100
            x = reshape(new_element(T, n^2), n, n)
            A = LinearAlgebra.Symmetric(x)
            $M.@rewrite(sum(A.^2))
        end
        function $(Symbol("bench_$(f)_matrix_lowertriangular"))(T)
            n = 100
            x = reshape(new_element(T, n^2), n, n)
            A = LinearAlgebra.LowerTriangular(x)
            $M.@rewrite(sum(A.^2))
        end
        function $(Symbol("bench_$(f)_matrix_uppertriangular"))(T)
            if !supports_nonlinear(T)
                return  # MA Issue#66
            end
            n = 100
            x = reshape(new_element(T, n^2), n, n)
            A = LinearAlgebra.UpperTriangular(x)
            b = collect(1:n)
            $M.@rewrite((A * b) .* b)
        end
        function $(Symbol("bench_$(f)_broadcast_dot"))(T)
            x = new_element(T, 10_000)
            $M.@rewrite(sum(x .* x))
        end
        function $(Symbol("bench_$(f)_dot"))(T)
            x = new_element(T, 10_000)
            $M.@rewrite(LinearAlgebra.dot(x, x))
        end
        function $(Symbol("bench_$(f)_broadcast_vector_add_mul"))(T)
            x = new_element(T, 10_000)
            $M.@rewrite(x .+ 3 .* x)
        end
        function $(Symbol("bench_$(f)_broadcast_vector_sub_mul"))(T)
            x = new_element(T, 10_000)
            $M.@rewrite(x .- 3 .* x)
        end
        function $(Symbol("bench_$(f)_broadcast_matrix_vector_add_mul"))(T)
            n = 1_000
            x = new_element(T, n)
            A = reshape(1:n^2, n, n)
            $M.@rewrite(x .+ A * x)
        end
        function $(Symbol("bench_$(f)_broadcast_matrix_vector_sub_mul"))(T)
            n = 1_000
            x = new_element(T, n)
            A = reshape(1:n^2, n, n)
            $M.@rewrite(x .- A * x)
        end
        function $(Symbol("bench_$(f)_broadcast_matrix_add"))(T)
            n = 1_000
            x = new_element(T, n^2)
            A = reshape(x, n, n)
            $M.@rewrite(A .+ 3 .* A')
        end
        function $(Symbol("bench_$(f)_mle"))(T)
            if !supports_nonlinear(T)
                return
            end
            n = 1_000
            μ, σ = new_element(T, 2)
            data = [sin(Float64(i)) for i in 1:n]
            $M.@rewrite(
                n / 2 * log(1 / (2 * π * σ^2)) -
                sum((data[i] - μ)^2 for i in 1:n) / (2 * σ^2)
            )
        end
    end
end

end

Results

Values in () are the ratio relative to the "MA" column.

julia> BenchmarkMA.print_results(r -> 1e-9 * minimum(r).time, results)
┌──────────────────┬─────────────────────────────────┬──────────────────────┬────────────────────┬──────────┐
│             type │                            name │                   IM │                MA2 │       MA │
│           String │                          String │              String? │            String? │  String? │
├──────────────────┼─────────────────────────────────┼──────────────────────┼────────────────────┼──────────┤
│         BigFloat │                   broadcast_dot │   9.15e-04 (1.00)    │ 9.15e-04 (1.00)    │ 9.15e-04 │
│         BigFloat │            broadcast_matrix_add │   1.34e-01 (1.11) ++1.20e-01 (0.99)    │ 1.21e-01 │
│         BigFloat │ broadcast_matrix_vector_add_mul │   7.26e-02 (0.99)    │ 7.27e-02 (0.99)    │ 7.32e-02 │
│         BigFloat │ broadcast_matrix_vector_sub_mul │   7.24e-02 (1.00)    │ 7.25e-02 (1.00)    │ 7.27e-02 │
│         BigFloat │        broadcast_vector_add_mul │   1.08e-03 (1.12) ++9.60e-04 (1.00)    │ 9.60e-04 │
│         BigFloat │         broadcast_vector_square │   9.01e-04 (1.00)    │ 9.01e-04 (1.00)    │ 9.01e-04 │
│         BigFloat │        broadcast_vector_sub_mul │   9.82e-04 (1.10) +8.96e-04 (1.00)    │ 8.95e-04 │
│         BigFloat │                             dot │   1.13e-03 (1.00)    │ 1.13e-03 (1.00)    │ 1.13e-03 │
│         BigFloat │          matrix_lowertriangular │   7.26e-04 (1.00)    │ 7.26e-04 (1.00)    │ 7.26e-04 │
│         BigFloat │                matrix_symmetric │   9.07e-04 (1.00)    │ 9.07e-04 (1.00)    │ 9.08e-04 │
│         BigFloat │          matrix_uppertriangular │   7.41e-04 (1.00)    │ 7.42e-04 (1.00)    │ 7.44e-04 │
│         BigFloat │                             mle │   1.66e-04 (0.38) --1.78e-04 (0.41) --4.36e-04 │
│         BigFloat │                            plus │   1.66e-07 (0.57) --2.50e-07 (0.86) --2.91e-07 │
│         BigFloat │                             pow │   1.66e-07 (0.44) --1.25e-07 (0.33) --3.75e-07 │
│         BigFloat │                summation_affine │   9.49e-04 (0.45) --2.10e-03 (0.99)    │ 2.13e-03 │
│         BigFloat │     summation_affine_minus_quad │   1.45e-05 (0.84) --1.66e-05 (0.97)    │ 1.72e-05 │
│         BigFloat │      summation_affine_plus_quad │   1.45e-05 (0.84) --1.68e-05 (0.98)    │ 1.72e-05 │
│         BigFloat │      summation_affine_post_mult │   9.49e-04 (0.30) --2.10e-03 (0.67) --3.15e-03 │
│         BigFloat │               summation_flatten │   4.16e-07 (0.32) --7.50e-07 (0.58) --1.29e-06 │
│         BigFloat │             summation_generator │   7.50e-07 (0.45) --1.00e-06 (0.60) --1.67e-06 │
│         BigFloat │     summation_quad_minus_affine │   1.45e-05 (0.86) --1.64e-05 (0.98)    │ 1.68e-05 │
│         BigFloat │      summation_quad_plus_affine │   1.45e-05 (0.87) --1.62e-05 (0.97)    │ 1.66e-05 │
│         BigFloat │             summation_quadratic │   1.49e-03 (0.54) --2.77e-03 (1.01)    │ 2.75e-03 │
│         BigFloat │   summation_quadratic_post_mult │   1.49e-03 (0.40) --2.82e-03 (0.75) --3.77e-03 │
│         BigFloat │                summation_vector │   3.98e-04 (1.00)    │ 3.99e-04 (1.00)    │ 3.99e-04 │
│         BigFloat │                  vector_product │   1.13e-03 (1.47) ++1.13e-03 (1.47) ++7.72e-04 │
│           BigInt │                   broadcast_dot │   8.23e-04 (1.00)    │ 8.24e-04 (1.00)    │ 8.22e-04 │
│           BigInt │            broadcast_matrix_add │   1.17e-01 (0.86) --1.37e-01 (1.00)    │ 1.36e-01 │
│           BigInt │ broadcast_matrix_vector_add_mul │   7.86e-02 (1.00)    │ 7.86e-02 (1.00)    │ 7.86e-02 │
│           BigInt │ broadcast_matrix_vector_sub_mul │   7.88e-02 (1.00)    │ 7.86e-02 (1.00)    │ 7.85e-02 │
│           BigInt │        broadcast_vector_add_mul │   1.09e-03 (0.86) --1.27e-03 (1.00)    │ 1.28e-03 │
│           BigInt │         broadcast_vector_square │   3.76e-04 (1.00)    │ 3.77e-04 (1.00)    │ 3.77e-04 │
│           BigInt │        broadcast_vector_sub_mul │   1.09e-03 (0.86) --1.27e-03 (1.00)    │ 1.27e-03 │
│           BigInt │                             dot │   1.09e-03 (1.00)    │ 1.10e-03 (1.00)    │ 1.10e-03 │
│           BigInt │          matrix_lowertriangular │   5.05e-04 (1.00)    │ 5.05e-04 (1.00)    │ 5.04e-04 │
│           BigInt │                matrix_symmetric │   3.87e-04 (1.00)    │ 3.87e-04 (1.00)    │ 3.87e-04 │
│           BigInt │          matrix_uppertriangular │   7.04e-04 (1.00)    │ 7.03e-04 (0.99)    │ 7.07e-04 │
│           BigInt │                             mle │   2.02e-04 (0.42) --2.26e-04 (0.47) --4.85e-04 │
│           BigInt │                            plus │   1.66e-07 (0.57) --2.91e-07 (1.00)    │ 2.91e-07 │
│           BigInt │                             pow │   1.25e-07 (1.00)    │ 1.25e-07 (1.00)    │ 1.25e-07 │
│           BigInt │                summation_affine │   1.09e-03 (0.52) --2.07e-03 (0.99)    │ 2.09e-03 │
│           BigInt │     summation_affine_minus_quad │   1.00e-05 (1.00)    │ 9.79e-06 (0.98)    │ 9.96e-06 │
│           BigInt │      summation_affine_plus_quad │   1.00e-05 (0.96)    │ 1.01e-05 (0.96)    │ 1.05e-05 │
│           BigInt │      summation_affine_post_mult │   1.09e-03 (0.36) --2.07e-03 (0.69) --2.99e-03 │
│           BigInt │               summation_flatten │   7.91e-07 (0.56) --8.75e-07 (0.62) --1.42e-06 │
│           BigInt │             summation_generator │   8.33e-07 (0.43) --1.04e-06 (0.54) --1.92e-06 │
│           BigInt │     summation_quad_minus_affine │   1.00e-05 (1.00)    │ 9.88e-06 (0.98)    │ 1.01e-05 │
│           BigInt │      summation_quad_plus_affine │   1.00e-05 (0.95)    │ 9.88e-06 (0.94) -1.05e-05 │
│           BigInt │             summation_quadratic │   1.15e-03 (0.51) --2.24e-03 (1.00)    │ 2.25e-03 │
│           BigInt │   summation_quadratic_post_mult │   1.15e-03 (0.36) --2.23e-03 (0.70) --3.20e-03 │
│           BigInt │                summation_vector │   3.41e-04 (1.00)    │ 3.41e-04 (1.00)    │ 3.41e-04 │
│           BigInt │                  vector_product │   1.10e-03 (2.81) ++1.10e-03 (2.82) ++3.91e-04 │
│          Float64 │                   broadcast_dot │   6.17e-06 (1.03)    │ 6.08e-06 (1.01)    │ 6.00e-06 │
│          Float64 │            broadcast_matrix_add │   7.76e-04 (0.45) --2.08e-03 (1.21) ++1.72e-03 │
│          Float64 │ broadcast_matrix_vector_add_mul │   3.07e-04 (1.00)    │ 3.06e-04 (1.00)    │ 3.06e-04 │
│          Float64 │ broadcast_matrix_vector_sub_mul │   3.07e-04 (1.00)    │ 3.06e-04 (1.00)    │ 3.06e-04 │
│          Float64 │        broadcast_vector_add_mul │   4.50e-06 (0.53) --8.58e-06 (1.00)    │ 8.54e-06 │
│          Float64 │         broadcast_vector_square │   6.00e-06 (0.98)    │ 6.04e-06 (0.99)    │ 6.13e-06 │
│          Float64 │        broadcast_vector_sub_mul │   4.46e-06 (0.52) --8.63e-06 (1.01)    │ 8.54e-06 │
│          Float64 │                             dot │   6.25e-06 (1.00)    │ 6.25e-06 (1.00)    │ 6.25e-06 │
│          Float64 │          matrix_lowertriangular │   1.76e-05 (1.00)    │ 1.77e-05 (1.00)    │ 1.77e-05 │
│          Float64 │                matrix_symmetric │   1.23e-05 (1.00)    │ 1.23e-05 (1.00)    │ 1.24e-05 │
│          Float64 │          matrix_uppertriangular │   4.29e-06 (0.98)    │ 4.33e-06 (0.99)    │ 4.38e-06 │
│          Float64 │                             mle │   6.75e-06 (0.08) --6.14e-05 (0.71) --8.68e-05 │
│          Float64 │                            plus │   4.10e-08 (1.00)    │ 4.20e-08 (1.02)    │ 4.10e-08 │
│          Float64 │                             pow │   4.20e-08 (1.00)    │ 4.10e-08 (0.98)    │ 4.20e-08 │
│          Float64 │                summation_affine │   1.13e-05 (0.03) --4.16e-04 (1.05) +3.94e-04 │
│          Float64 │     summation_affine_minus_quad │   6.25e-07 (0.07) --8.29e-06 (0.97)    │ 8.58e-06 │
│          Float64 │      summation_affine_plus_quad │   6.25e-07 (0.07) --8.17e-06 (0.95)    │ 8.58e-06 │
│          Float64 │      summation_affine_post_mult │   1.13e-05 (0.03) --4.09e-04 (0.97)    │ 4.21e-04 │
│          Float64 │               summation_flatten │   3.75e-07 (1.28) ++2.92e-07 (1.00)    │ 2.92e-07 │
│          Float64 │             summation_generator │   2.91e-07 (0.78) --3.75e-07 (1.00)    │ 3.75e-07 │
│          Float64 │     summation_quad_minus_affine │   5.83e-07 (0.07) --8.79e-06 (1.06) +8.29e-06 │
│          Float64 │      summation_quad_plus_affine │   6.25e-07 (0.08) --8.29e-06 (1.01)    │ 8.21e-06 │
│          Float64 │             summation_quadratic │   1.13e-05 (0.02) --5.64e-04 (0.96)    │ 5.86e-04 │
│          Float64 │   summation_quadratic_post_mult │   1.13e-05 (0.02) --5.95e-04 (1.08) +5.53e-04 │
│          Float64 │                summation_vector │   3.46e-06 (1.00)    │ 3.46e-06 (1.00)    │ 3.46e-06 │
│          Float64 │                  vector_product │   6.25e-06 (0.57) --6.25e-06 (0.57) --1.09e-05 │
│            Int32 │                   broadcast_dot │   3.79e-06 (0.99)    │ 3.79e-06 (0.99)    │ 3.83e-06 │
│            Int32 │            broadcast_matrix_add │   5.30e-04 (0.62) --1.23e-03 (1.44) ++8.55e-04 │
│            Int32 │ broadcast_matrix_vector_add_mul │   4.74e-04 (1.00)    │ 4.73e-04 (1.00)    │ 4.73e-04 │
│            Int32 │ broadcast_matrix_vector_sub_mul │   4.74e-04 (1.00)    │ 4.73e-04 (1.00)    │ 4.73e-04 │
│            Int32 │        broadcast_vector_add_mul │   4.67e-06 (0.55) --8.42e-06 (1.00)    │ 8.42e-06 │
│            Int32 │         broadcast_vector_square │   3.71e-06 (0.96)    │ 3.88e-06 (1.00)    │ 3.88e-06 │
│            Int32 │        broadcast_vector_sub_mul │   4.79e-06 (0.57) --8.50e-06 (1.00)    │ 8.46e-06 │
│            Int32 │                             dot │   2.12e-06 (1.00)    │ 2.12e-06 (1.00)    │ 2.12e-06 │
│            Int32 │          matrix_lowertriangular │   1.62e-05 (0.99)    │ 1.62e-05 (0.99)    │ 1.63e-05 │
│            Int32 │                matrix_symmetric │   1.16e-05 (1.00)    │ 1.16e-05 (1.00)    │ 1.15e-05 │
│            Int32 │          matrix_uppertriangular │   8.04e-06 (1.01)    │ 7.96e-06 (0.99)    │ 8.00e-06 │
│            Int32 │                             mle │   6.71e-06 (0.08) --5.61e-05 (0.67) --8.35e-05 │
│            Int32 │                            plus │   4.10e-08 (1.00)    │ 4.10e-08 (1.00)    │ 4.10e-08 │
│            Int32 │                             pow │   4.10e-08 (1.00)    │ 1.00e-12 (0.00) --4.10e-08 │
│            Int32 │                summation_affine │   4.42e-06 (0.01) --3.74e-04 (1.09) +3.43e-04 │
│            Int32 │     summation_affine_minus_quad │   4.58e-07 (0.07) --6.33e-06 (0.91) -6.96e-06 │
│            Int32 │      summation_affine_plus_quad │   4.58e-07 (0.07) --6.67e-06 (0.99)    │ 6.71e-06 │
│            Int32 │      summation_affine_post_mult │   4.42e-06 (0.01) --3.70e-04 (0.99)    │ 3.74e-04 │
│            Int32 │               summation_flatten │   3.33e-07 (1.33) ++2.91e-07 (1.16) ++2.50e-07 │
│            Int32 │             summation_generator │   3.33e-07 (1.00)    │ 2.92e-07 (0.88) --3.33e-07 │
│            Int32 │     summation_quad_minus_affine │   4.58e-07 (0.07) --6.88e-06 (0.98)    │ 7.04e-06 │
│            Int32 │      summation_quad_plus_affine │   4.58e-07 (0.07) --6.83e-06 (1.01)    │ 6.79e-06 │
│            Int32 │             summation_quadratic │   5.92e-06 (0.01) --5.19e-04 (1.01)    │ 5.16e-04 │
│            Int32 │   summation_quadratic_post_mult │   5.96e-06 (0.01) --5.38e-04 (1.03)    │ 5.22e-04 │
│            Int32 │                summation_vector │   2.25e-06 (1.00)    │ 2.25e-06 (1.00)    │ 2.25e-06 │
│            Int32 │                  vector_product │   1.96e-06 (0.22) --2.12e-06 (0.24) --8.75e-06 │
│ JuMP.VariableRef │                   broadcast_dot │   4.56e-03 (0.90) --4.51e-03 (0.89) --5.09e-03 │
│ JuMP.VariableRef │            broadcast_matrix_add │   9.25e-01 (1.01)    │ 9.08e-01 (0.99)    │ 9.14e-01 │
│ JuMP.VariableRef │ broadcast_matrix_vector_add_mul │   1.25e-01 (0.99)    │ 1.26e-01 (1.00)    │ 1.26e-01 │
│ JuMP.VariableRef │ broadcast_matrix_vector_sub_mul │   1.23e-01 (0.99)    │ 1.23e-01 (0.99)    │ 1.24e-01 │
│ JuMP.VariableRef │        broadcast_vector_add_mul │   4.55e-03 (1.00)    │ 4.57e-03 (1.00)    │ 4.56e-03 │
│ JuMP.VariableRef │         broadcast_vector_square │   4.49e-03 (0.88) --4.53e-03 (0.89) --5.11e-03 │
│ JuMP.VariableRef │        broadcast_vector_sub_mul │   4.52e-03 (1.00)    │ 4.52e-03 (1.00)    │ 4.53e-03 │
│ JuMP.VariableRef │                             dot │   2.06e-03 (0.79) --2.08e-03 (0.79) --2.61e-03 │
│ JuMP.VariableRef │          matrix_lowertriangular │   3.66e-03 (0.92) -3.71e-03 (0.93) -3.99e-03 │
│ JuMP.VariableRef │                matrix_symmetric │   4.46e-03 (0.93) -4.50e-03 (0.94) -4.81e-03 │
│ JuMP.VariableRef │          matrix_uppertriangular │   4.10e-08 (1.00)    │ 4.10e-08 (1.00)    │ 4.10e-08 │
│ JuMP.VariableRef │                             mle │   4.10e-08 (0.98)    │ 4.10e-08 (0.98)    │ 4.20e-08 │
│ JuMP.VariableRef │                            plus │   1.96e-06 (1.00)    │ 1.96e-06 (1.00)    │ 1.96e-06 │
│ JuMP.VariableRef │                             pow │   4.10e-08 (1.00)    │ 4.20e-08 (1.02)    │ 4.10e-08 │
│ JuMP.VariableRef │                summation_affine │ 2.80e-01 (122.02) ++2.30e-03 (1.00)    │ 2.30e-03 │
│ JuMP.VariableRef │     summation_affine_minus_quad │   1.47e-04 (2.41) ++6.60e-05 (1.08) +6.11e-05 │
│ JuMP.VariableRef │      summation_affine_plus_quad │   1.49e-04 (2.46) ++6.09e-05 (1.01)    │ 6.05e-05 │
│ JuMP.VariableRef │      summation_affine_post_mult │ 2.70e-01 (117.76) ++2.52e-03 (1.10) +2.29e-03 │
│ JuMP.VariableRef │               summation_flatten │   2.96e-06 (1.29) ++2.33e-06 (1.02)    │ 2.29e-06 │
│ JuMP.VariableRef │             summation_generator │   3.08e-06 (1.30) ++2.42e-06 (1.02)    │ 2.38e-06 │
│ JuMP.VariableRef │     summation_quad_minus_affine │   1.47e-04 (2.41) ++6.60e-05 (1.09) +6.08e-05 │
│ JuMP.VariableRef │      summation_quad_plus_affine │   1.43e-04 (2.35) ++6.54e-05 (1.08) +6.07e-05 │
│ JuMP.VariableRef │             summation_quadratic │  3.28e-01 (64.29) ++5.10e-03 (1.00)    │ 5.10e-03 │
│ JuMP.VariableRef │   summation_quadratic_post_mult │  3.26e-01 (62.85) ++5.44e-03 (1.05)    │ 5.18e-03 │
│ JuMP.VariableRef │                summation_vector │   1.88e-03 (0.83) --1.86e-03 (0.82) --2.27e-03 │
│ JuMP.VariableRef │                  vector_product │   2.09e-03 (1.00)    │ 2.07e-03 (1.00)    │ 2.08e-03 │
└──────────────────┴─────────────────────────────────┴──────────────────────┴────────────────────┴──────────┘

julia> BenchmarkMA.print_results(r -> minimum(r).memory / 1024, results)
┌──────────────────┬─────────────────────────────────┬──────────────────────┬────────────────────┬──────────┐
│             type │                            name │                   IM │                MA2 │       MA │
│           String │                          String │              String? │            String? │  String? │
├──────────────────┼─────────────────────────────────┼──────────────────────┼────────────────────┼──────────┤
│         BigFloat │                   broadcast_dot │   2.19e+03 (1.00)    │ 2.19e+03 (1.00)    │ 2.19e+03 │
│         BigFloat │            broadcast_matrix_add │   3.20e+05 (0.98)    │ 3.28e+05 (1.00)    │ 3.28e+05 │
│         BigFloat │ broadcast_matrix_vector_add_mul │   2.04e+05 (1.00)    │ 2.04e+05 (1.00)    │ 2.04e+05 │
│         BigFloat │ broadcast_matrix_vector_sub_mul │   2.04e+05 (1.00)    │ 2.04e+05 (1.00)    │ 2.04e+05 │
│         BigFloat │        broadcast_vector_add_mul │   3.20e+03 (0.98)    │ 3.28e+03 (1.00)    │ 3.28e+03 │
│         BigFloat │         broadcast_vector_square │   2.19e+03 (1.00)    │ 2.19e+03 (1.00)    │ 2.19e+03 │
│         BigFloat │        broadcast_vector_sub_mul │   3.20e+03 (0.98)    │ 3.28e+03 (1.00)    │ 3.28e+03 │
│         BigFloat │                             dot │   3.13e+03 (1.00)    │ 3.13e+03 (1.00)    │ 3.13e+03 │
│         BigFloat │          matrix_lowertriangular │   2.19e+03 (1.00)    │ 2.19e+03 (1.00)    │ 2.19e+03 │
│         BigFloat │                matrix_symmetric │   2.19e+03 (1.00)    │ 2.19e+03 (1.00)    │ 2.19e+03 │
│         BigFloat │          matrix_uppertriangular │   2.13e+03 (0.99)    │ 2.13e+03 (0.99)    │ 2.14e+03 │
│         BigFloat │                             mle │   3.14e+02 (0.59) --2.28e+02 (0.43) --5.32e+02 │
│         BigFloat │                            plus │   3.67e-01 (1.00)    │ 3.67e-01 (1.00)    │ 3.67e-01 │
│         BigFloat │                             pow │   3.67e-01 (0.78) --3.67e-01 (0.78) --4.69e-01 │
│         BigFloat │                summation_affine │   3.12e+03 (0.57) --5.45e+03 (1.00)    │ 5.45e+03 │
│         BigFloat │     summation_affine_minus_quad │   4.15e+01 (1.95) ++2.14e+01 (1.00)    │ 2.13e+01 │
│         BigFloat │      summation_affine_plus_quad │   4.15e+01 (1.95) ++2.14e+01 (1.00)    │ 2.13e+01 │
│         BigFloat │      summation_affine_post_mult │   3.13e+03 (0.37) --5.45e+03 (0.64) --8.50e+03 │
│         BigFloat │               summation_flatten │   8.75e-01 (0.35) --1.48e+00 (0.59) --2.52e+00 │
│         BigFloat │             summation_generator │   1.14e+00 (0.33) --1.99e+00 (0.58) --3.43e+00 │
│         BigFloat │     summation_quad_minus_affine │   4.15e+01 (1.95) ++2.14e+01 (1.00)    │ 2.13e+01 │
│         BigFloat │      summation_quad_plus_affine │   4.15e+01 (1.95) ++2.14e+01 (1.00)    │ 2.13e+01 │
│         BigFloat │             summation_quadratic │   4.14e+03 (0.64) --6.47e+03 (1.00)    │ 6.47e+03 │
│         BigFloat │   summation_quadratic_post_mult │   4.14e+03 (0.44) --6.47e+03 (0.68) --9.51e+03 │
│         BigFloat │                summation_vector │   1.09e+03 (1.00)    │ 1.09e+03 (1.00)    │ 1.09e+03 │
│         BigFloat │                  vector_product │   3.13e+03 (2.86) ++3.13e+03 (2.86) ++1.09e+03 │
│           BigInt │                   broadcast_dot │   1.02e+03 (1.00)    │ 1.02e+03 (1.00)    │ 1.02e+03 │
│           BigInt │            broadcast_matrix_add │   1.48e+05 (0.95)    │ 1.56e+05 (1.00)    │ 1.56e+05 │
│           BigInt │ broadcast_matrix_vector_add_mul │   9.40e+04 (1.00)    │ 9.40e+04 (1.00)    │ 9.40e+04 │
│           BigInt │ broadcast_matrix_vector_sub_mul │   9.40e+04 (1.00)    │ 9.40e+04 (1.00)    │ 9.40e+04 │
│           BigInt │        broadcast_vector_add_mul │   1.48e+03 (0.95)    │ 1.56e+03 (1.00)    │ 1.56e+03 │
│           BigInt │         broadcast_vector_square │   5.47e+02 (1.00)    │ 5.47e+02 (1.00)    │ 5.47e+02 │
│           BigInt │        broadcast_vector_sub_mul │   1.48e+03 (0.95)    │ 1.56e+03 (1.00)    │ 1.56e+03 │
│           BigInt │                             dot │   1.41e+03 (1.00)    │ 1.41e+03 (1.00)    │ 1.41e+03 │
│           BigInt │          matrix_lowertriangular │   7.41e+02 (1.00)    │ 7.41e+02 (1.00)    │ 7.41e+02 │
│           BigInt │                matrix_symmetric │   5.47e+02 (1.00)    │ 5.47e+02 (1.00)    │ 5.47e+02 │
│           BigInt │          matrix_uppertriangular │   9.49e+02 (0.99)    │ 9.49e+02 (0.99)    │ 9.54e+02 │
│           BigInt │                             mle │   5.17e+02 (0.76) --4.31e+02 (0.63) --6.81e+02 │
│           BigInt │                            plus │   1.88e-01 (1.00)    │ 1.88e-01 (1.00)    │ 1.88e-01 │
│           BigInt │                             pow │   1.41e-01 (0.78) --1.41e-01 (0.78) --1.80e-01 │
│           BigInt │                summation_affine │   1.41e+03 (0.58) --2.41e+03 (1.00)    │ 2.41e+03 │
│           BigInt │     summation_affine_minus_quad │   1.42e+01 (2.94) ++4.88e+00 (1.01)    │ 4.83e+00 │
│           BigInt │      summation_affine_plus_quad │   1.42e+01 (2.94) ++4.88e+00 (1.01)    │ 4.83e+00 │
│           BigInt │      summation_affine_post_mult │   1.41e+03 (0.39) --2.41e+03 (0.67) --3.58e+03 │
│           BigInt │               summation_flatten │   7.58e-01 (0.25) --8.83e-01 (0.30) --2.98e+00 │
│           BigInt │             summation_generator │   8.36e-01 (0.20) --1.08e+00 (0.26) --4.13e+00 │
│           BigInt │     summation_quad_minus_affine │   1.42e+01 (2.94) ++4.88e+00 (1.01)    │ 4.83e+00 │
│           BigInt │      summation_quad_plus_affine │   1.42e+01 (2.94) ++4.88e+00 (1.01)    │ 4.83e+00 │
│           BigInt │             summation_quadratic │   1.41e+03 (0.58) --2.41e+03 (1.00)    │ 2.41e+03 │
│           BigInt │   summation_quadratic_post_mult │   1.41e+03 (0.39) --2.41e+03 (0.67) --3.58e+03 │
│           BigInt │                summation_vector │   4.69e+02 (1.00)    │ 4.69e+02 (1.00)    │ 4.69e+02 │
│           BigInt │                  vector_product │   1.41e+03 (3.00) ++1.41e+03 (3.00) ++4.69e+02 │
│          Float64 │                   broadcast_dot │   1.56e+02 (1.00)    │ 1.56e+02 (1.00)    │ 1.56e+02 │
│          Float64 │            broadcast_matrix_add │   1.56e+04 (0.67) --2.34e+04 (1.00)    │ 2.34e+04 │
│          Float64 │ broadcast_matrix_vector_add_mul │   2.39e+01 (1.00)    │ 2.38e+01 (1.00)    │ 2.38e+01 │
│          Float64 │ broadcast_matrix_vector_sub_mul │   2.39e+01 (1.00)    │ 2.38e+01 (1.00)    │ 2.38e+01 │
│          Float64 │        broadcast_vector_add_mul │   1.56e+02 (0.67) --2.35e+02 (1.00)    │ 2.35e+02 │
│          Float64 │         broadcast_vector_square │   1.56e+02 (1.00)    │ 1.56e+02 (1.00)    │ 1.56e+02 │
│          Float64 │        broadcast_vector_sub_mul │   1.56e+02 (0.67) --2.35e+02 (1.00)    │ 2.35e+02 │
│          Float64 │                             dot │   7.82e+01 (1.00)    │ 7.82e+01 (1.00)    │ 7.82e+01 │
│          Float64 │          matrix_lowertriangular │   1.57e+02 (1.00)    │ 1.57e+02 (1.00)    │ 1.57e+02 │
│          Float64 │                matrix_symmetric │   1.57e+02 (1.00)    │ 1.57e+02 (1.00)    │ 1.57e+02 │
│          Float64 │          matrix_uppertriangular │   8.10e+01 (0.99)    │ 8.10e+01 (0.99)    │ 8.18e+01 │
│          Float64 │                             mle │   8.30e+00 (0.08) --7.07e+01 (0.69) --1.02e+02 │
│          Float64 │                            plus │   9.38e-02 (1.00)    │ 9.38e-02 (1.00)    │ 9.38e-02 │
│          Float64 │                             pow │   9.38e-02 (1.00)    │ 9.38e-02 (1.00)    │ 9.38e-02 │
│          Float64 │                summation_affine │   7.82e+01 (0.11) --6.87e+02 (1.00)    │ 6.87e+02 │
│          Float64 │     summation_affine_minus_quad │   1.03e+00 (0.12) --8.72e+00 (1.00)    │ 8.70e+00 │
│          Float64 │      summation_affine_plus_quad │   1.03e+00 (0.12) --8.72e+00 (1.00)    │ 8.70e+00 │
│          Float64 │      summation_affine_post_mult │   7.83e+01 (0.11) --6.87e+02 (1.00)    │ 6.87e+02 │
│          Float64 │               summation_flatten │   2.03e-01 (0.93) -2.19e-01 (1.00)    │ 2.19e-01 │
│          Float64 │             summation_generator │   1.88e-01 (0.67) --2.81e-01 (1.00)    │ 2.81e-01 │
│          Float64 │     summation_quad_minus_affine │   1.03e+00 (0.12) --8.72e+00 (1.00)    │ 8.70e+00 │
│          Float64 │      summation_quad_plus_affine │   1.03e+00 (0.12) --8.72e+00 (1.00)    │ 8.70e+00 │
│          Float64 │             summation_quadratic │   7.82e+01 (0.09) --8.43e+02 (1.00)    │ 8.43e+02 │
│          Float64 │   summation_quadratic_post_mult │   7.83e+01 (0.09) --8.43e+02 (1.00)    │ 8.43e+02 │
│          Float64 │                summation_vector │   7.82e+01 (1.00)    │ 7.82e+01 (1.00)    │ 7.82e+01 │
│          Float64 │                  vector_product │   7.82e+01 (1.00)    │ 7.82e+01 (1.00)    │ 7.82e+01 │
│            Int32 │                   broadcast_dot │   7.83e+01 (1.00)    │ 7.83e+01 (1.00)    │ 7.83e+01 │
│            Int32 │            broadcast_matrix_add │   1.17e+04 (0.50) --2.34e+04 (1.00)    │ 2.34e+04 │
│            Int32 │ broadcast_matrix_vector_add_mul │   2.00e+01 (0.83) --2.40e+01 (1.00)    │ 2.40e+01 │
│            Int32 │ broadcast_matrix_vector_sub_mul │   2.00e+01 (0.83) --2.40e+01 (1.00)    │ 2.40e+01 │
│            Int32 │        broadcast_vector_add_mul │   1.17e+02 (0.50) --2.35e+02 (1.00)    │ 2.35e+02 │
│            Int32 │         broadcast_vector_square │   7.83e+01 (1.00)    │ 7.83e+01 (1.00)    │ 7.83e+01 │
│            Int32 │        broadcast_vector_sub_mul │   1.17e+02 (0.50) --2.35e+02 (1.00)    │ 2.35e+02 │
│            Int32 │                             dot │   3.91e+01 (1.00)    │ 3.91e+01 (1.00)    │ 3.91e+01 │
│            Int32 │          matrix_lowertriangular │   7.85e+01 (1.00)    │ 7.85e+01 (1.00)    │ 7.85e+01 │
│            Int32 │                matrix_symmetric │   7.85e+01 (1.00)    │ 7.85e+01 (1.00)    │ 7.85e+01 │
│            Int32 │          matrix_uppertriangular │   1.20e+02 (0.99)    │ 1.20e+02 (0.99)    │ 1.21e+02 │
│            Int32 │                             mle │   8.20e+00 (0.12) --7.06e+01 (1.00)    │ 7.06e+01 │
│            Int32 │                            plus │   6.25e-02 (1.00)    │ 6.25e-02 (1.00)    │ 6.25e-02 │
│            Int32 │                             pow │   6.25e-02 (1.00)    │ 6.25e-02 (1.00)    │ 6.25e-02 │
│            Int32 │                summation_affine │   3.92e+01 (0.08) --4.91e+02 (1.00)    │ 4.91e+02 │
│            Int32 │     summation_affine_minus_quad │   5.78e-01 (1.19) ++4.84e-01 (1.00)    │ 4.84e-01 │
│            Int32 │      summation_affine_plus_quad │   5.78e-01 (1.19) ++4.84e-01 (1.00)    │ 4.84e-01 │
│            Int32 │      summation_affine_post_mult │   3.92e+01 (0.08) --4.91e+02 (1.00)    │ 4.92e+02 │
│            Int32 │               summation_flatten │   1.72e-01 (1.10) +7.81e-02 (0.50) --1.56e-01 │
│            Int32 │             summation_generator │   1.56e-01 (0.77) --9.38e-02 (0.46) --2.03e-01 │
│            Int32 │     summation_quad_minus_affine │   5.78e-01 (1.19) ++4.84e-01 (1.00)    │ 4.84e-01 │
│            Int32 │      summation_quad_plus_affine │   5.78e-01 (1.19) ++4.84e-01 (1.00)    │ 4.84e-01 │
│            Int32 │             summation_quadratic │   3.92e+01 (0.08) --4.91e+02 (1.00)    │ 4.91e+02 │
│            Int32 │   summation_quadratic_post_mult │   3.92e+01 (0.08) --4.91e+02 (1.00)    │ 4.92e+02 │
│            Int32 │                summation_vector │   3.91e+01 (1.00)    │ 3.91e+01 (1.00)    │ 3.92e+01 │
│            Int32 │                  vector_product │   3.91e+01 (1.00)    │ 3.91e+01 (1.00)    │ 3.91e+01 │
│ JuMP.VariableRef │                   broadcast_dot │   1.70e+04 (0.90) -1.70e+04 (0.90) -1.87e+04 │
│ JuMP.VariableRef │            broadcast_matrix_add │   1.73e+06 (0.99)    │ 1.75e+06 (1.00)    │ 1.75e+06 │
│ JuMP.VariableRef │ broadcast_matrix_vector_add_mul │   1.83e+05 (1.00)    │ 1.83e+05 (1.00)    │ 1.83e+05 │
│ JuMP.VariableRef │ broadcast_matrix_vector_sub_mul │   1.83e+05 (1.00)    │ 1.83e+05 (1.00)    │ 1.83e+05 │
│ JuMP.VariableRef │        broadcast_vector_add_mul │   1.74e+04 (0.99)    │ 1.77e+04 (1.00)    │ 1.77e+04 │
│ JuMP.VariableRef │         broadcast_vector_square │   1.70e+04 (0.90) -1.70e+04 (0.90) -1.87e+04 │
│ JuMP.VariableRef │        broadcast_vector_sub_mul │   1.74e+04 (0.99)    │ 1.77e+04 (1.00)    │ 1.77e+04 │
│ JuMP.VariableRef │                             dot │   5.33e+03 (0.75) --5.33e+03 (0.75) --7.11e+03 │
│ JuMP.VariableRef │          matrix_lowertriangular │   1.38e+04 (0.95) -1.38e+04 (0.95) -1.45e+04 │
│ JuMP.VariableRef │                matrix_symmetric │   1.59e+04 (0.96)    │ 1.59e+04 (0.96)    │ 1.67e+04 │
│ JuMP.VariableRef │          matrix_uppertriangular │    0.00e+00 (NaN)    │  0.00e+00 (NaN)    │ 0.00e+00 │
│ JuMP.VariableRef │                             mle │    0.00e+00 (NaN)    │  0.00e+00 (NaN)    │ 0.00e+00 │
│ JuMP.VariableRef │                            plus │   1.05e+01 (1.00)    │ 1.05e+01 (1.00)    │ 1.05e+01 │
│ JuMP.VariableRef │                             pow │    0.00e+00 (NaN)    │  0.00e+00 (NaN)    │ 0.00e+00 │
│ JuMP.VariableRef │                summation_affine │ 3.60e+06 (682.07) ++5.28e+03 (1.00)    │ 5.28e+03 │
│ JuMP.VariableRef │     summation_affine_minus_quad │  2.01e+03 (10.65) ++2.01e+02 (1.06) +1.88e+02 │
│ JuMP.VariableRef │      summation_affine_plus_quad │  2.00e+03 (10.64) ++1.94e+02 (1.03)    │ 1.88e+02 │
│ JuMP.VariableRef │      summation_affine_post_mult │ 3.60e+06 (682.18) ++5.28e+03 (1.00)    │ 5.28e+03 │
│ JuMP.VariableRef │               summation_flatten │   1.36e+01 (1.21) ++1.13e+01 (1.00)    │ 1.13e+01 │
│ JuMP.VariableRef │             summation_generator │   1.47e+01 (1.30) ++1.13e+01 (1.00)    │ 1.13e+01 │
│ JuMP.VariableRef │     summation_quad_minus_affine │  2.00e+03 (10.88) ++1.93e+02 (1.05)    │ 1.84e+02 │
│ JuMP.VariableRef │      summation_quad_plus_affine │  2.00e+03 (10.87) ++1.92e+02 (1.04)    │ 1.84e+02 │
│ JuMP.VariableRef │             summation_quadratic │ 5.66e+06 (323.38) ++1.75e+04 (1.00)    │ 1.75e+04 │
│ JuMP.VariableRef │   summation_quadratic_post_mult │ 5.66e+06 (323.46) ++1.75e+04 (1.00)    │ 1.75e+04 │
│ JuMP.VariableRef │                summation_vector │   4.67e+03 (0.81) --4.67e+03 (0.81) --5.74e+03 │
│ JuMP.VariableRef │                  vector_product │   5.33e+03 (1.00)    │ 5.33e+03 (1.00)    │ 5.33e+03 │
└──────────────────┴─────────────────────────────────┴──────────────────────┴────────────────────┴──────────┘

codecov bot commented Oct 28, 2022

Codecov Report

Base: 83.66% // Head: 84.75% // Increases project coverage by +1.08% 🎉

Coverage data is based on head (f73fddb) compared to base (201aafb).
Patch coverage: 99.30% of modified lines in pull request are covered.

❗ Current head f73fddb differs from pull request most recent head fefb084. Consider uploading reports for the commit fefb084 to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #170      +/-   ##
==========================================
+ Coverage   83.66%   84.75%   +1.08%     
==========================================
  Files          20       21       +1     
  Lines        1898     2033     +135     
==========================================
+ Hits         1588     1723     +135     
  Misses        310      310              
Impacted Files Coverage Δ
src/MutableArithmetics.jl 97.22% <ø> (ø)
src/rewrite_generic.jl 99.21% <99.21%> (ø)
src/rewrite.jl 80.39% <100.00%> (+1.20%) ⬆️


@blegat (Member) commented Oct 28, 2022

Noticeably, just evaluating things is faster, likely because of type stability?

That's a bit surprising but as this was never benchmarked properly, we might be doing something inefficient somewhere indeed.

@odow (Member, Author) commented Oct 28, 2022

we might be doing something inefficient somewhere indeed.

I need to check with JuMP types. It's probably much more efficient if the cost of creating a new object is high.
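
For context, a minimal sketch of the difference being described, using the public MA.operate!! API (illustrative only; acc is just a placeholder name):

import MutableArithmetics as MA
x = big(2)
y = x + x                      # out-of-place: allocates a new BigInt every time
acc = big(0)
acc = MA.operate!!(+, acc, x)  # in place when acc is mutable: may reuse its storage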


@odow (Member, Author) commented Nov 6, 2022

This PR also closes #44

julia> MA2.@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)
1

julia> MA.@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)
ERROR: ArgumentError: reducing over an empty collection is not allowed
julia> @macroexpand MA2.@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)
quote
    #= /Users/oscar/.julia/dev/MutableArithmetics/src/new_rewrite.jl:33 =#
    let
        #= /Users/oscar/.julia/dev/MutableArithmetics/src/new_rewrite.jl:34 =#
        begin
            #= /Users/oscar/.julia/dev/MutableArithmetics/src/new_rewrite.jl:30 =#
            var"#9192###2317" = MutableArithmetics.Zero()
            for i = 1:0
                #= /Users/oscar/.julia/dev/MutableArithmetics/src/new_rewrite.jl:233 =#
                var"#9192###2317" = (MutableArithmetics.operate!!)(MutableArithmetics.MutableArithmetics2.:+, var"#9192###2317", 1)
            end
            var"#9193###2318" = 1 ^ 2
            var"#9194###2316" = (MutableArithmetics.operate!!)(MutableArithmetics.add_mul, MutableArithmetics.Zero(), var"#9192###2317", var"#9193###2318")
            var"#9195###2319" = 1 .+ var"#9194###2316"
        end
        #= /Users/oscar/.julia/dev/MutableArithmetics/src/new_rewrite.jl:35 =#
        var"#9195###2319"
    end
end

@odow (Member, Author) commented Nov 7, 2022

The summation_affine_post_mult benchmark is interesting. The problem is that

@rewrite(sum(i * x[i] for i in 1:n) * 2)
# becomes
y = MA.Zero()
for i in 1:n
    y = MA.operate!!(MA.add_mul, y, i, x[i])
end
y = MA.operate!!(MA.add_mul, MA.Zero(), y, 2)

So the last operation allocates a new expression.

Ideally, we'd do something like

y = MA.operate!!(*, y, 2)

but this isn't what MA.@rewrite currently does.

Edit: I've done this and updated the benchmarks

@odow force-pushed the od/rewrite-2 branch 4 times, most recently from 42f7017 to 4e01cde, on November 8, 2022 at 00:15
@odow (Member, Author) commented Nov 8, 2022

I think it's safe to say that this is now even slightly more efficient than the existing MA.@rewrite. The code is also a lot simpler to understand.
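
(In the snippet below, df is assumed to hold the per-benchmark time ratios relative to the MA column as numbers, not the formatted strings printed above, and Statistics is assumed to be imported.)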

stats = DataFrames.combine(
    DataFrames.groupby(df, [:type, :method]),
    :result => Statistics.mean => :mean,
    :result => Statistics.median => :median,
    :result => (x -> exp(Statistics.mean(log.(x)))) => :geomean,
)

julia> DataFrames.sort!(stats, [:method, :type])
15×5 DataFrame
 Row │ type              method  mean       median    geomean  
     │ String            String  Float64    Float64   Float64  
─────┼─────────────────────────────────────────────────────────
   1 │ BigFloat          IM       0.644393  0.543036  0.574129
   2 │ BigInt            IM       0.805905  0.617568  0.662454
   3 │ Float64           IM       0.801565  0.901345  0.698718
   4 │ Int64             IM       0.781128  0.955712  0.688505
   5 │ JuMP.VariableRef  IM      65.746     1.19753   5.23297
   6 │ BigFloat          MA       1.0       1.0       1.0
   7 │ BigInt            MA       1.0       1.0       1.0
   8 │ Float64           MA       1.0       1.0       1.0
   9 │ Int64             MA       1.0       1.0       1.0
  10 │ JuMP.VariableRef  MA       1.0       1.0       1.0
  11 │ BigFloat          MA2      0.798462  0.96      0.759529
  12 │ BigInt            MA2      0.80857   0.977416  0.76787
  13 │ Float64           MA2      0.965268  0.997299  0.959214
  14 │ Int64             MA2      1.0056    0.989867  0.998947
  15 │ JuMP.VariableRef  MA2      0.970589  0.987654  0.968822

The main decision is whether to add it as a new feature, or whether to replace the existing MA.@rewrite.

@odow (Member, Author) commented Nov 8, 2022

So one major thing this PR doesn't do is rewrite broadcasts. I need to:

  • Find a benchmark where this matters
  • Consider adding it

@odow (Member, Author) commented Nov 8, 2022

@blegat can you find the benchmarks that demonstrated the need to rewrite broadcasts?

See current benchmarks. broadcast_matrix_vector_add_mul is 2x slower. Fixed

@blegat (Member) commented Nov 10, 2022

I don't see that broadcast_matrix_vector_add_mul is slower anymore. This is looking good. I am tempted to say we could just replace the existing rewrite and tag v1.1. The most worrying result at the moment is summation_affine_post_mult, for which IM is much faster than MA2 for BigInt and BigFloat.

@odow (Member, Author) commented Nov 10, 2022

I don't see that broadcast_matrix_vector_add_mul is slower anymore

Yes. I fixed this by rewriting :.+ and :.-.
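
For reference, a rough sketch of the kind of lowering this refers to, assuming MA's broadcast!! API and a vector x of mutable elements (the code the macro actually generates may differ):

import MutableArithmetics as MA

# x .+ 3 .* x is fused into a single (possibly mutating) broadcast over
# add_mul, instead of materializing 3 .* x first:
y = MA.broadcast!!(MA.add_mul, x, 3, x)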

@odow (Member, Author) commented Nov 11, 2022

I am tempted to say we could just replace the existing rewrite and tag v1.1

The changes break JuMP's tests by changing printing and the error types that are thrown, so I'd be in favor of keeping it as a new feature.

@ccoffrin commented

The changes break JuMP's tests by changing printing and the error types that are thrown, so I'd be in favor of keeping it as a new feature.

If that is the case, and given that MA2 seems to be strictly an improvement over MA(?), why not do a full replacement of MA and tag v2.0? JuMP could then update to the new version in its next release. Unless there is still some uncertainty around the design or possible drawbacks of MA2?

@odow (Member, Author) commented Nov 14, 2022

Releasing v2.0 seems a bit drastic. There are 748 dependent packages, so releasing a breaking change has cost: https://juliahub.com/ui/Packages/MutableArithmetics/EoEec/1.0.5.

Perhaps we could have rewrite(expr; assume_sums_are_linear::Bool = true), and then opting out would use the new rewrite rules.
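
For reference, the opt-in keyword that eventually landed is spelled move_factors_into_sums (see the examples further down this thread); a rough usage sketch, with x and n as placeholders:

import MutableArithmetics as MA

MA.@rewrite(sum(i * x[i] for i in 1:n) * 2)                                 # default: old rewrite rules
MA.@rewrite(sum(i * x[i] for i in 1:n) * 2, move_factors_into_sums = false) # opt in to the new rewrite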

possible drawbacks of MA2?

The main drawback is a lack of testing. This is the problem with Julia's interface design; it's hard to tell which methods we are missing without running examples. We know that the current design works for people; we can't be 100% confident that the new design also works for the same use cases.

@ccoffrin commented

Releasing v2.0 seems a bit drastic. There are 748 dependent packages, so releasing a breaking change has cost: https://juliahub.com/ui/Packages/MutableArithmetics/EoEec/1.0.5.

Wow, I see your point. I had no idea this package was so widely used.

@odow changed the title from "WIP: rewrite experiments" to "Add new rewrite that does not make strong assumptions on result type" on Nov 14, 2022
Fix rewrite with views

Update

Support broadcasting

Fix handling of *

Fix generators

Add tests for repeated sums

Fix formatting

Updates

Place new rewrite behind opt-in kwarg

More coverage

Update docstrings
@odow (Member, Author) commented Nov 15, 2022

Okay, I think I'm happy for this to be merged. It's strictly a feature addition hidden behind an opt-in flag.

@odow (Member, Author) commented Nov 16, 2022

Hmm. Found a few more edge cases that we need to fix before merging.

julia> using JuMP

julia> using LinearAlgebra

julia> model = Model()
A JuMP Model
Feasibility problem with:
Variables: 0
Model mode: AUTOMATIC
CachingOptimizer state: NO_OPTIMIZER
Solver name: No optimizer attached.

julia> @variable(model, x[1:3])
3-element Vector{VariableRef}:
 x[1]
 x[2]
 x[3]

julia> A = LowerTriangular(rand(3, 3))
3×3 LowerTriangular{Float64, Matrix{Float64}}:
 0.507834   ⋅         ⋅
 0.178216  0.774911   ⋅
 0.498208  0.438567  0.165303

julia> const MA = JuMP._MA
MutableArithmetics

julia> MA.@rewrite(A * x, move_factors_into_sums = true)
3-element Vector{AffExpr}:
 0.5078344563174482 x[1]
 0.1782159773716443 x[1] + 0.7749111979573089 x[2]
 0.49820804791789297 x[1] + 0.4385670932233323 x[2] + 0.16530329346630457 x[3]

julia> MA.@rewrite(A * x, move_factors_into_sums = false)
ERROR: MethodError: Cannot `convert` an object of type QuadExpr to an object of type AffExpr
Closest candidates are:
  convert(::Type{GenericAffExpr{T, V}}, ::GenericAffExpr{T, V}) where {T, V} at /Users/oscar/.julia/packages/JuMP/Z1pVn/src/aff_expr.jl:538
  convert(::Type{GenericAffExpr{T, V}}, ::GenericAffExpr{S, V}) where {S, T, V} at /Users/oscar/.julia/packages/JuMP/Z1pVn/src/aff_expr.jl:545
  convert(::Type{GenericAffExpr{T, V}}, ::Union{Number, UniformScaling}) where {T, V} at /Users/oscar/.julia/packages/JuMP/Z1pVn/src/aff_expr.jl:534
  ...
Stacktrace:
 [1] setindex!
   @ ./array.jl:845 [inlined]
 [2] lmul!(A::LowerTriangular{AffExpr, Matrix{AffExpr}}, B::Vector{AffExpr})
   @ LinearAlgebra /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/triangular.jl:951
 [3] *(A::LowerTriangular{Float64, Matrix{Float64}}, B::Vector{VariableRef})
   @ LinearAlgebra /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/triangular.jl:1657
 [4] top-level scope
   @ ~/.julia/dev/MutableArithmetics/src/rewrite.jl:317

@odow (Member, Author) commented Nov 16, 2022

By not doing fancy rewrites, this now fixes (or at least, improves) the examples from #66 as well:

julia> using JuMP

julia> using LinearAlgebra

julia> model = Model()
A JuMP Model
Feasibility problem with:
Variables: 0
Model mode: AUTOMATIC
CachingOptimizer state: NO_OPTIMIZER
Solver name: No optimizer attached.

julia> @variable(model, x[1:3])
3-element Vector{VariableRef}:
 x[1]
 x[2]
 x[3]

julia> A = LowerTriangular(rand(3, 3))
3×3 LowerTriangular{Float64, Matrix{Float64}}:
 0.840885   ⋅         ⋅
 0.944976  0.178466   ⋅
 0.440437  0.252618  0.637874

julia> c = [1, 2, 3]
3-element Vector{Int64}:
 1
 2
 3

julia> const MA = JuMP._MA
MutableArithmetics

julia> MA.@rewrite(A * x, move_factors_into_sums = true)
3-element Vector{AffExpr}:
 0.8408850607032472 x[1]
 0.944975752029531 x[1] + 0.17846559026575726 x[2]
 0.44043661564033143 x[1] + 0.25261817336168857 x[2] + 0.6378741155843315 x[3]

julia> MA.@rewrite(A * x, move_factors_into_sums = false)
3-element Vector{AffExpr}:
 0.8408850607032472 x[1]
 0.944975752029531 x[1] + 0.17846559026575726 x[2]
 0.44043661564033143 x[1] + 0.25261817336168857 x[2] + 0.6378741155843315 x[3]

julia> MA.@rewrite((A * x) .* c, move_factors_into_sums = false)
3-element Vector{AffExpr}:
 0.8408850607032472 x[1]
 1.889951504059062 x[1] + 0.3569311805315145 x[2]
 1.3213098469209943 x[1] + 0.7578545200850657 x[2] + 1.9136223467529945 x[3]
julia> model = Model()
A JuMP Model
Feasibility problem with:
Variables: 0
Model mode: AUTOMATIC
CachingOptimizer state: NO_OPTIMIZER
Solver name: No optimizer attached.

julia> @variable(model, x[1:2])
2-element Vector{VariableRef}:
 x[1]
 x[2]

julia> @variable(model, y[1:2, 1:2])
2×2 Matrix{VariableRef}:
 y[1,1]  y[1,2]
 y[2,1]  y[2,2]

julia> A = ones(2, 2)
2×2 Matrix{Float64}:
 1.0  1.0
 1.0  1.0

julia> MA.@rewrite(y - Diagonal(1 .- x) * A, move_factors_into_sums = true)
ERROR: MethodError: Cannot `convert` an object of type QuadExpr to an object of type AffExpr
Closest candidates are:
  convert(::Type{GenericAffExpr{T, V}}, ::GenericAffExpr{T, V}) where {T, V} at /Users/oscar/.julia/packages/JuMP/Z1pVn/src/aff_expr.jl:538
  convert(::Type{GenericAffExpr{T, V}}, ::GenericAffExpr{S, V}) where {S, T, V} at /Users/oscar/.julia/packages/JuMP/Z1pVn/src/aff_expr.jl:545
  convert(::Type{GenericAffExpr{T, V}}, ::Union{Number, UniformScaling}) where {T, V} at /Users/oscar/.julia/packages/JuMP/Z1pVn/src/aff_expr.jl:534
  ...
Stacktrace:
  [1] setindex!
    @ ./array.jl:845 [inlined]
  [2] setindex!
    @ ./multidimensional.jl:645 [inlined]
  [3] macro expansion
    @ ./broadcast.jl:984 [inlined]
  [4] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [5] copyto!
    @ ./broadcast.jl:983 [inlined]
  [6] copyto!
    @ ./broadcast.jl:936 [inlined]
  [7] materialize!
    @ ./broadcast.jl:894 [inlined]
  [8] materialize!
    @ ./broadcast.jl:891 [inlined]
  [9] lmul!(D::Diagonal{AffExpr, Vector{AffExpr}}, B::Matrix{AffExpr})
    @ LinearAlgebra /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/diagonal.jl:212
 [10] *(D::Diagonal{AffExpr, Vector{AffExpr}}, A::Matrix{Float64})
    @ LinearAlgebra /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/diagonal.jl:201
 [11] sub_mul
    @ ~/.julia/dev/MutableArithmetics/src/MutableArithmetics.jl:39 [inlined]
 [12] operate
    @ ~/.julia/dev/MutableArithmetics/src/interface.jl:199 [inlined]
 [13] operate_fallback!!
    @ ~/.julia/dev/MutableArithmetics/src/interface.jl:571 [inlined]
 [14] operate!!(::typeof(MutableArithmetics.sub_mul), ::Matrix{VariableRef}, ::Diagonal{AffExpr, Vector{AffExpr}}, ::Matrix{Float64})
    @ MutableArithmetics ~/.julia/dev/MutableArithmetics/src/rewrite.jl:89
 [15] top-level scope
    @ ~/.julia/dev/MutableArithmetics/src/rewrite.jl:322

julia> MA.@rewrite(y - Diagonal(1 .- x) * A, move_factors_into_sums = false)
2×2 Matrix{AffExpr}:
 y[1,1] + x[1] - 1  y[1,2] + x[1] - 1
 y[2,1] + x[2] - 1  y[2,2] + x[2] - 1

@odow (Member, Author) commented Nov 17, 2022

Since this is opt-in, perhaps we should merge this, and then continue to tweak in follow-up PRs.

@blegat (Member) commented Nov 21, 2022

I agree, let's merge so that we can see changes to this more easily

@odow merged commit c715a0c into master on Nov 21, 2022
@odow deleted the od/rewrite-2 branch on November 21, 2022 at 18:22
Successfully merging this pull request may close these issues: "Rewrite of sum()*constant is suboptimal" and "Empty sums".