
Slow Hilbert-Schmidt inner product #234

Closed · araujoms opened this issue Oct 26, 2023 · 12 comments · Fixed by #248

Comments

@araujoms (Contributor)

When computing the Hilbert-Schmidt inner product of two matrices a, b, it is much faster to use dot(a,b) than the definition tr(a*b), as this avoids a matrix multiplication. I've benchmarked it numerically to confirm, and the difference is orders of magnitude. However, this doesn't carry over to JuMP variables: for some reason dot(a,b) is the slower one!

I've tried replacing dot(a,b) with dot(vec(a),vec(b)), and it does make things better. For small dimensions it is still slower than tr(a*b), but at dimensions where speed matters (around 200) it becomes orders of magnitude faster. So a workaround exists, but such a low-level hack shouldn't be necessary; dot should just do the right thing.
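For reference, the two formulations agree here because, for real symmetric a and b, tr(a*b) equals the entrywise sum of a[i,j]*b[i,j], which is exactly what dot computes without ever forming the product a*b. A quick sanity check with plain floats (not part of the original benchmark):

using LinearAlgebra

x = randn(4, 4); a = Symmetric(x + x')
y = randn(4, 4); b = Symmetric(y + y')

# dot(a, b) sums a[i,j] * b[i,j] directly (O(d^2));
# tr(a * b) first forms the full product (O(d^3)).
tr(a * b) ≈ dot(a, b)  # true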

Here is the code I used for benchmarking:

using JuMP
using LinearAlgebra

# Random d×d real symmetric matrix
function random_matrix(d)
    x = randn(Float64, (d, d))
    return Symmetric(x + x')
end

# dot(a, b) vs. tr(a * b) with a JuMP PSD matrix variable
function timejump(d)
    model = JuMP.Model()
    a = random_matrix(d)
    b = JuMP.@variable(model, [1:d, 1:d] in JuMP.PSDCone())
    display(@elapsed dot(a, b))
    display(@elapsed tr(a * b))
end

# Same comparison with both matrices vectorized first
function timejumpvec(d)
    model = JuMP.Model()
    a = random_matrix(d)
    b = JuMP.@variable(model, [1:d, 1:d] in JuMP.PSDCone())
    display(@elapsed dot(vec(a), vec(b)))
    display(@elapsed tr(a * b))
end
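Note that @elapsed times a single call, so the first run includes compilation. A sketch of a steadier measurement using BenchmarkTools (assuming it is installed):

using BenchmarkTools, LinearAlgebra

a = randn(100, 100)
b = randn(100, 100)

# @belapsed runs the expression many times and reports the minimum;
# the $-interpolation avoids benchmarking access to untyped globals.
@belapsed dot($a, $b)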
@blegat (Member) commented Oct 26, 2023

It's slow because it doesn't get dispatched to

function LinearAlgebra.dot(
    lhs::AbstractArray,
    rhs::AbstractArray{<:AbstractMutable},
)
    return operate(LinearAlgebra.dot, lhs, rhs)
end

but instead to the generic method in LinearAlgebra:
https://github.com/JuliaLang/julia/blob/master/stdlib/LinearAlgebra/src/symmetric.jl#L458

You can do

operate(dot, a, b)

to make sure the fast method is called. Redirecting the methods is tricky because we would need to be more specialized than all the methods implemented in LinearAlgebra.
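A minimal sketch of that workaround, assuming MutableArithmetics is loaded as MA (operate is not exported, so it must be qualified or imported):

using JuMP, LinearAlgebra
import MutableArithmetics as MA

model = Model()
a = Symmetric(randn(3, 3))
@variable(model, b[1:3, 1:3] in PSDCone())

# Calls MutableArithmetics' specialized implementation directly,
# bypassing the generic method in LinearAlgebra/symmetric.jl:
expr = MA.operate(LinearAlgebra.dot, a, b)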

@araujoms (Contributor, Author)

Thanks, that does indeed work. Rather ironically, dot(vec(a),vec(b)) was calling the slow version; operate(dot,vec(a),vec(b)) goes to MutableArithmetics and is faster. Of course, operate(dot,a,b) is much faster than both, as it presumably uses the proper algorithm for symmetric matrices. Something is going wrong with the complex case, though: there operate(dot,a,b) is still faster than tr(a*b), as it must be, but slower than both vectorized versions.

In the meanwhile, I was trying to implement a dot function for JuMP variables using the following warning for guidance: "The addition operator has been used on JuMP expressions a large number of times. This warning is safe to ignore but may indicate that model generation is slower than necessary. For performance reasons, you should not add expressions in a loop. Instead of x += y, use add_to_expression!(x,y) to modify x in place. If y is a single variable, you may also use add_to_expression!(x, coef, y) for x += coef*y."

I assume the warning is rather outdated, as add_to_expression! just gives an error, and it doesn't mention that the proper function is already implemented elsewhere.
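For reference, the pattern the warning recommends does work in its intended scalar setting; a minimal sketch (hypothetical variable names):

using JuMP

model = Model()
@variable(model, y[1:1000])

x = AffExpr(0.0)                      # mutable accumulator
for i in 1:1000
    add_to_expression!(x, 2.0, y[i])  # in-place x += 2.0 * y[i]
end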

In any case, since the proper method is already implemented, we need to actually use it. What's the difficulty with the dispatch? Can't we just define something like LinearAlgebra.dot(lhs::AbstractMatrix{T}, rhs::AbstractMatrix) where {T<:AbstractJuMPScalar} and the necessary variants?

@blegat (Member) commented Oct 26, 2023

If you get that warning, it means you are not being dispatched to the right method in MutableArithmetics.

What's the difficulty with the dispatch? Can't we just define something like LinearAlgebra.dot(lhs::AbstractMatrix{T}, rhs::AbstractMatrix) where {T<:AbstractJuMPScalar} and the necessary variants?

The issue is that the list of necessary variants is long, and it changes with every Julia release. Now that the stdlib will be decoupled it might become more feasible, but still. The list is even longer in SparseArrays than in LinearAlgebra, last time I checked.
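A toy illustration of the dispatch hazard (not MutableArithmetics code): when two signatures overlap, either one is strictly more specific and silently wins (as LinearAlgebra's Symmetric method does here), or neither is and the call becomes ambiguous:

f(x::Number,  y::Integer) = "first"
f(x::Integer, y::Number)  = "second"

f(1.0, 2)  # "first" -- only one signature matches
f(1, 2)    # ERROR: MethodError: f(::Int64, ::Int64) is ambiguous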

@odow (Member) commented Oct 26, 2023

Moving this to MutableArithmetics.

odow transferred this issue from jump-dev/JuMP.jl on Oct 26, 2023
@odow (Member) commented Oct 26, 2023

I guess we could add the Symmetric case here:

function LinearAlgebra.dot(
    lhs::AbstractArray{<:AbstractMutable},
    rhs::AbstractArray,
)
    return operate(LinearAlgebra.dot, lhs, rhs)
end

function LinearAlgebra.dot(
    lhs::AbstractArray,
    rhs::AbstractArray{<:AbstractMutable},
)
    return operate(LinearAlgebra.dot, lhs, rhs)
end

function LinearAlgebra.dot(
    lhs::AbstractArray{<:AbstractMutable},
    rhs::AbstractArray{<:AbstractMutable},
)
    return operate(LinearAlgebra.dot, lhs, rhs)
end

@odow (Member) commented Oct 26, 2023

So we could make the dot(::Symmetric, ::Symmetric) case work (three more methods), but I don't know if we want to add dot(::Symmetric, ::AbstractMatrix) and dot(::AbstractMatrix, ::Symmetric) as well. That's nine methods for this one case, and we still wouldn't cover all possibilities.

Instead of dot, we could suggest people use sum(a .* b)...
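For reference, that fallback as a helper (hypothetical name), with the conjugate that the complex case needs to match dot; conj is a no-op for real entries:

# Broadcast fallback: no specialized dot method required.
hs_inner(a, b) = sum(conj(a) .* b)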

@araujoms (Contributor, Author)

I see, so dot already dispatches correctly when called with vectors or full matrices; only the Symmetric and Hermitian cases are missing. I can do it myself if that's the holdup. The problem with sum(a .* b) is that it's again an order of magnitude slower than the proper algorithm.

About the warning: the problem is that add_to_expression! doesn't actually work here, so if I had done something like

for i in 1:n
    x += y
end

I would get the warning, try to fix it in the recommended way, and fail.

@araujoms (Contributor, Author)

In LinearAlgebra they simply used a for loop to cover all the combinations without copy-pasting code: https://github.com/JuliaLang/julia/blob/75881a9edd8a2d4b46664f27fae98c6ec16f6e38/stdlib/LinearAlgebra/src/symmetric.jl#L456-L497
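Applied to the methods sketched above, the same pattern might look like this (a hypothetical sketch, not MutableArithmetics source; AbstractMutable and operate are MutableArithmetics names):

import LinearAlgebra
import MutableArithmetics: AbstractMutable, operate

# One @eval loop emits the three tie-breaking methods per wrapper
# instead of copy-pasting them for Symmetric and Hermitian.
for W in (:Symmetric, :Hermitian)
    @eval begin
        function LinearAlgebra.dot(
            lhs::LinearAlgebra.$W{<:AbstractMutable},
            rhs::LinearAlgebra.$W,
        )
            return operate(LinearAlgebra.dot, lhs, rhs)
        end
        function LinearAlgebra.dot(
            lhs::LinearAlgebra.$W,
            rhs::LinearAlgebra.$W{<:AbstractMutable},
        )
            return operate(LinearAlgebra.dot, lhs, rhs)
        end
        function LinearAlgebra.dot(
            lhs::LinearAlgebra.$W{<:AbstractMutable},
            rhs::LinearAlgebra.$W{<:AbstractMutable},
        )
            return operate(LinearAlgebra.dot, lhs, rhs)
        end
    end
end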

@odow (Member) commented Oct 26, 2023

is that it's again an order of magnitude slower than the proper algorithm.

Is this a meaningful bottleneck though?

In LinearAlgebra they simply used a for loop to cover all the combinations without copy-pasting code

Sure, we can use codegen to simplify things, but there's a cost (in precompilation size and time) to having a large number of methods in MutableArithmetics.

@araujoms (Contributor, Author)

is that it's again an order of magnitude slower than the proper algorithm.

Is this a meaningful bottleneck though?

With tr(a'*b) it is. With sum(conj(a) .* b) it would be challenging to write a program where that's the bottleneck. A more relevant issue is readability: dot(a,b) is just so much nicer.

In LinearAlgebra they simply used a for loop to cover all the combinations without copy-pasting code

Sure, we can use codegen to simplify things, but there's a cost (in precompilation size and time) to having a large number of methods in MutableArithmetics.

Is this a meaningful bottleneck though?

@odow (Member) commented Oct 26, 2023

With tr(a'*b) it is

Can you give examples of runtime?

Is this a meaningful bottleneck though?

MutableArithmetics has >900 dependent packages: https://juliahub.com/ui/Packages/General/MutableArithmetics

The increase in loading time affects everyone. Your example might be only a second or two slower for one particular model.

@araujoms (Contributor, Author)

With the function

import MutableArithmetics: operate  # operate is not exported

function timejump(d)
    model = JuMP.Model()
    a = random_matrix(d)
    b = JuMP.@variable(model, [1:d, 1:d] in JuMP.PSDCone())
    display(@elapsed operate(dot, a, b))
    display(@elapsed tr(a * b))
    display(@elapsed dot(a, b))
end

I get

julia> timejump(500)
0.038824475    # operate(dot, a, b)
16.632690687   # tr(a * b)
42.862116192   # dot(a, b)

julia> timejump(800)
0.100039467    # operate(dot, a, b)

For d = 800 I couldn't even get times for tr(a*b) and dot(a,b); the computer just freezes for lack of RAM (I have 12 GiB).
