From 2f4008fd235beb8d3272f3a779d9edcaac4d2d3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Ara=C3=BAjo?= Date: Wed, 8 May 2024 20:08:44 +0200 Subject: [PATCH] Eliminate allocations in relative entropy cone (#841) --- src/Cones/arrayutilities.jl | 11 ++-- src/Cones/epitrrelentropytri.jl | 102 ++++++++++++++++++++------------ 2 files changed, 71 insertions(+), 42 deletions(-) diff --git a/src/Cones/arrayutilities.jl b/src/Cones/arrayutilities.jl index 036951ddf..84ba6d889 100644 --- a/src/Cones/arrayutilities.jl +++ b/src/Cones/arrayutilities.jl @@ -382,16 +382,17 @@ function eig_dot_kron!( @assert issymmetric(inner) # must be symmetric (wrapper is less efficient) rt2i = inv(rt2) copyto!(V, vecs') # allows fast column slices - V_views = [view(V, :, i) for i in 1:size(inner, 1)] scals = (R <: Complex{T} ? [rt2i, rt2i * im] : [rt2i]) # real and imag parts col_idx = 1 - @inbounds for (j, V_j) in enumerate(V_views) + @inbounds for j in 1:size(inner, 1) + @views V_j = V[:, j] for i in 1:(j - 1), scal in scals - mul!(temp1, V_j, V_views[i]', scal, false) + @views V_i = V[:, i] + mul!(temp1, V_j, V_i', scal, false) @. temp2 = inner * (temp1 + temp1') mul!(temp1, Hermitian(temp2, :U), V) - mul!(temp2, V', temp1) + mul!(temp2, vecs, temp1) @views smat_to_svec!(skr[:, col_idx], temp2, rt2) col_idx += 1 end @@ -399,7 +400,7 @@ function eig_dot_kron!( mul!(temp2, V_j, V_j') temp2 .*= inner mul!(temp1, Hermitian(temp2, :U), V) - mul!(temp2, V', temp1) + mul!(temp2, vecs, temp1) @views smat_to_svec!(skr[:, col_idx], temp2, rt2) col_idx += 1 end diff --git a/src/Cones/epitrrelentropytri.jl b/src/Cones/epitrrelentropytri.jl index 0e097cf24..49a89f1e4 100644 --- a/src/Cones/epitrrelentropytri.jl +++ b/src/Cones/epitrrelentropytri.jl @@ -69,6 +69,16 @@ mutable struct EpiTrRelEntropyTri{T <: Real, R <: RealOrComplex{T}} <: Cone{T} mat3::Matrix{R} mat4::Matrix{R} Wsim_Δ3::Array{R, 3} + V_dir_sim::Matrix{R} + W_dir_sim::Matrix{R} + VW_dir_sim::Matrix{R} + diff_dot_V_VV::Matrix{R} + diff_dot_V_VW::Matrix{R} + diff_dot_W_WW::Matrix{R} + Vvd::Vector{T} + Wwd::Vector{T} + VWwd::Vector{T} + VWvd::Vector{T} function EpiTrRelEntropyTri{T, R}( dim::Int; @@ -130,6 +140,16 @@ function setup_extra_data!( cone.mat3 = zeros(R, d, d) cone.mat4 = zeros(R, d, d) cone.Wsim_Δ3 = zeros(R, d, d, d) + cone.V_dir_sim = zeros(R, d, d) + cone.W_dir_sim = zeros(R, d, d) + cone.VW_dir_sim = zeros(R, d, d) + cone.diff_dot_V_VV = zeros(R, d, d) + cone.diff_dot_V_VW = zeros(R, d, d) + cone.diff_dot_W_WW = zeros(R, d, d) + cone.Vvd = zeros(T, vw_dim) + cone.Wwd = zeros(T, vw_dim) + cone.VWwd = zeros(T, vw_dim) + cone.VWvd = zeros(T, vw_dim) return end @@ -206,8 +226,11 @@ function update_grad( mat3 = cone.mat3 g = cone.grad - spectral_outer!(Vi, V_vecs, inv.(V_λ), cone.mat) - spectral_outer!(Wi, W_vecs, inv.(W_λ), cone.mat) + @views tempvec = cone.tempvec[1:(cone.d)] + @. tempvec = inv(V_λ) + spectral_outer!(Vi, V_vecs, tempvec, cone.mat) + @. tempvec = inv(W_λ) + spectral_outer!(Wi, W_vecs, tempvec, cone.mat) g[1] = -zi @@ -371,6 +394,7 @@ function dder3( dir::AbstractVector{T}, ) where {T <: Real, R <: RealOrComplex{T}} cone.dder3_aux_updated || update_dder3_aux(cone) + dder3 = cone.dder3 rt2 = cone.rt2 V_idxs = cone.V_idxs W_idxs = cone.W_idxs @@ -383,57 +407,56 @@ function dder3( dzdW = cone.dzdW mat = cone.mat mat2 = cone.mat2 - dder3 = cone.dder3 + V_dir_sim = cone.V_dir_sim + W_dir_sim = cone.W_dir_sim + VW_dir_sim = cone.VW_dir_sim + diff_dot_V_VV = cone.diff_dot_V_VV + diff_dot_V_VW = cone.diff_dot_V_VW + diff_dot_W_WW = cone.diff_dot_W_WW + Vvd = cone.Vvd + Wwd = cone.Wwd + VWwd = cone.VWwd + VWvd = cone.VWvd u_dir = dir[1] @views v_dir = dir[V_idxs] @views w_dir = dir[W_idxs] # v, w - # TODO prealloc - V_dir = Hermitian(zero(mat), :U) - W_dir = Hermitian(zero(mat), :U) - V_dir_sim = zero(mat) - W_dir_sim = zero(mat) - VW_dir_sim = zero(mat) - W_part_1 = zero(mat) - V_part_1 = zero(mat) - d3WlogVdV = zero(mat) - diff_dot_V_VV = zero(mat) - diff_dot_V_VW = zero(mat) - diff_dot_W_WW = zero(mat) - - svec_to_smat!(V_dir.data, v_dir, rt2) - svec_to_smat!(W_dir.data, w_dir, rt2) - spectral_outer!(V_dir_sim, V_vecs', V_dir, mat) - spectral_outer!(W_dir_sim, W_vecs', W_dir, mat) - spectral_outer!(VW_dir_sim, V_vecs', W_dir, mat) + V_dir = svec_to_smat!(mat, v_dir, rt2) + spectral_outer!(V_dir_sim, V_vecs', Hermitian(V_dir, :U), mat2) + + W_dir = svec_to_smat!(mat, w_dir, rt2) + spectral_outer!(W_dir_sim, W_vecs', Hermitian(W_dir, :U), mat2) + spectral_outer!(VW_dir_sim, V_vecs', Hermitian(W_dir, :U), mat2) for k in 1:(cone.d) @views mul!(mat2[:, k], cone.Wsim_Δ3[:, :, k], V_dir_sim[:, k]) end + @. mat = mat2 + mat2' spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2) - Vvd = smat_to_svec!(zero(w_dir), mat, rt2) + smat_to_svec!(Vvd, mat, rt2) @. mat = W_dir_sim * cone.Δ2_W spectral_outer!(mat, W_vecs, Hermitian(mat, :U), mat2) - Wwd = smat_to_svec!(zero(w_dir), mat, rt2) + smat_to_svec!(Wwd, mat, rt2) @. mat = VW_dir_sim * cone.Δ2_V spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2) - VWwd = smat_to_svec!(zero(w_dir), mat, rt2) + smat_to_svec!(VWwd, mat, rt2) @. mat = V_dir_sim * cone.Δ2_V spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2) - VWvd = smat_to_svec!(zero(w_dir), mat, rt2) + smat_to_svec!(VWvd, mat, rt2) const0 = zi * u_dir + dot(v_dir, dzdV) + dot(w_dir, dzdW) const1 = abs2(const0) + zi * (-dot(v_dir, VWwd) + (-dot(v_dir, Vvd) + dot(w_dir, Wwd)) / 2) - V_part_1a = -const0 * (Vvd + VWwd) - W_part_1a = Wwd - VWvd + Vvd .+= VWwd + V_part_1a = Vvd .*= -const0 + W_part_1a = Wwd .-= VWvd # u dder3[1] = zi * const1 @@ -457,9 +480,11 @@ function dder3( end # v - d3WlogVdV!(d3WlogVdV, Δ3_V, V_λ, V_dir_sim, cone.W_sim, mat) - svec_to_smat!(V_part_1, V_part_1a, rt2) - @. V_dir_sim /= sqrt(V_λ)' + d3WlogVdV = d3WlogVdV!(mat2, Δ3_V, V_λ, V_dir_sim, cone.W_sim, mat) + V_part_1 = svec_to_smat!(cone.mat3, V_part_1a, rt2) + @views tempvec = cone.tempvec[1:(cone.d)] + tempvec .= sqrt.(V_λ) + @. V_dir_sim /= tempvec' ldiv!(Diagonal(V_λ), V_dir_sim) V_part_2 = d3WlogVdV @. V_part_2 += diff_dot_V_VW + diff_dot_V_VW' @@ -471,10 +496,12 @@ function dder3( @. dder3_V += const1 * dzdV # w - svec_to_smat!(W_part_1, W_part_1a, rt2) + W_part_1 = svec_to_smat!(V_part_1, W_part_1a, rt2) spectral_outer!(mat2, V_vecs, Hermitian(diff_dot_V_VV, :U), mat) axpby!(true, mat2, const0, W_part_1) - @. W_dir_sim /= sqrt(W_λ)' + @views tempvec = cone.tempvec[1:(cone.d)] + tempvec .= sqrt.(W_λ) + @. W_dir_sim /= tempvec' ldiv!(Diagonal(W_λ), W_dir_sim) W_part_2 = diff_dot_W_WW mul!(W_part_2, W_dir_sim, W_dir_sim', true, -zi) @@ -551,31 +578,32 @@ function d2zdV2!( ) where {T <: Real, R <: RealOrComplex{T}} d = size(vecs, 1) V = copyto!(mat, vecs') - V_views = [view(V, :, i) for i in 1:d] rt2i = inv(rt2) scals = (R <: Complex{T} ? [rt2i, rt2i * im] : [rt2i]) col_idx = 1 @inbounds for j in 1:d + @views V_j = V[:, j] for i in 1:(j - 1), scal in scals - mul!(mat3, V_views[j], V_views[i]', scal, false) + @views V_i = V[:, i] + mul!(mat3, V_j, V_i', scal, false) @. mat2 = mat3 + mat3' for k in 1:d @views mul!(mat3[:, k], Wsim_Δ3[:, :, k], mat2[:, k]) end # mat2 = vecs * (mat3 + mat3) * vecs' @. mat2 = mat3 + mat3' - spectral_outer!(mat2, V', Hermitian(mat2, :U), mat3) + spectral_outer!(mat2, vecs, Hermitian(mat2, :U), mat3) @views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2) col_idx += 1 end - mul!(mat2, V_views[j], V_views[j]') + mul!(mat2, V_j, V_j') for k in 1:d @views mul!(mat3[:, k], Wsim_Δ3[:, :, k], mat2[:, k]) end @. mat2 = mat3 + mat3' - spectral_outer!(mat2, V', Hermitian(mat2, :U), mat3) + spectral_outer!(mat2, vecs, Hermitian(mat2, :U), mat3) @views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2) col_idx += 1 end