diff --git a/src/Cones/epitrrelentropytri.jl b/src/Cones/epitrrelentropytri.jl index 0e097cf24..18817cb12 100644 --- a/src/Cones/epitrrelentropytri.jl +++ b/src/Cones/epitrrelentropytri.jl @@ -69,6 +69,16 @@ mutable struct EpiTrRelEntropyTri{T <: Real, R <: RealOrComplex{T}} <: Cone{T} mat3::Matrix{R} mat4::Matrix{R} Wsim_Δ3::Array{R, 3} + V_dir_sim::Matrix{R} + W_dir_sim::Matrix{R} + VW_dir_sim::Matrix{R} + diff_dot_V_VV::Matrix{R} + diff_dot_V_VW::Matrix{R} + diff_dot_W_WW::Matrix{R} + Vvd::Vector{T} + Wwd::Vector{T} + VWwd::Vector{T} + VWvd::Vector{T} function EpiTrRelEntropyTri{T, R}( dim::Int; @@ -130,6 +140,16 @@ function setup_extra_data!( cone.mat3 = zeros(R, d, d) cone.mat4 = zeros(R, d, d) cone.Wsim_Δ3 = zeros(R, d, d, d) + cone.V_dir_sim = zeros(R, d, d) + cone.W_dir_sim = zeros(R, d, d) + cone.VW_dir_sim = zeros(R, d, d) + cone.diff_dot_V_VV = zeros(R, d, d) + cone.diff_dot_V_VW = zeros(R, d, d) + cone.diff_dot_W_WW = zeros(R, d, d) + cone.Vvd = zeros(T, vw_dim) + cone.Wwd = zeros(T, vw_dim) + cone.VWwd = zeros(T, vw_dim) + cone.VWvd = zeros(T, vw_dim) return end @@ -206,8 +226,11 @@ function update_grad( mat3 = cone.mat3 g = cone.grad - spectral_outer!(Vi, V_vecs, inv.(V_λ), cone.mat) - spectral_outer!(Wi, W_vecs, inv.(W_λ), cone.mat) + @views tempvec = cone.tempvec[1:(cone.d)] + @. tempvec = inv(V_λ) + spectral_outer!(Vi, V_vecs, tempvec, cone.mat) + @. tempvec = inv(W_λ) + spectral_outer!(Wi, W_vecs, tempvec, cone.mat) g[1] = -zi @@ -371,6 +394,7 @@ function dder3( dir::AbstractVector{T}, ) where {T <: Real, R <: RealOrComplex{T}} cone.dder3_aux_updated || update_dder3_aux(cone) + dder3 = cone.dder3 rt2 = cone.rt2 V_idxs = cone.V_idxs W_idxs = cone.W_idxs @@ -383,57 +407,57 @@ function dder3( dzdW = cone.dzdW mat = cone.mat mat2 = cone.mat2 - dder3 = cone.dder3 + V_dir_sim = cone.V_dir_sim + W_dir_sim = cone.W_dir_sim + VW_dir_sim = cone.VW_dir_sim + diff_dot_V_VV = cone.diff_dot_V_VV + diff_dot_V_VW = cone.diff_dot_V_VW + diff_dot_W_WW = cone.diff_dot_W_WW + + Vvd = cone.Vvd + Wwd = cone.Wwd + VWwd = cone.VWwd + VWvd = cone.VWvd u_dir = dir[1] @views v_dir = dir[V_idxs] @views w_dir = dir[W_idxs] # v, w - # TODO prealloc - V_dir = Hermitian(zero(mat), :U) - W_dir = Hermitian(zero(mat), :U) - V_dir_sim = zero(mat) - W_dir_sim = zero(mat) - VW_dir_sim = zero(mat) - W_part_1 = zero(mat) - V_part_1 = zero(mat) - d3WlogVdV = zero(mat) - diff_dot_V_VV = zero(mat) - diff_dot_V_VW = zero(mat) - diff_dot_W_WW = zero(mat) - - svec_to_smat!(V_dir.data, v_dir, rt2) - svec_to_smat!(W_dir.data, w_dir, rt2) - spectral_outer!(V_dir_sim, V_vecs', V_dir, mat) - spectral_outer!(W_dir_sim, W_vecs', W_dir, mat) - spectral_outer!(VW_dir_sim, V_vecs', W_dir, mat) + V_dir = svec_to_smat!(mat, v_dir, rt2) + spectral_outer!(V_dir_sim, V_vecs', Hermitian(V_dir, :U), mat2) + + W_dir = svec_to_smat!(mat, w_dir, rt2) + spectral_outer!(W_dir_sim, W_vecs', Hermitian(W_dir, :U), mat2) + spectral_outer!(VW_dir_sim, V_vecs', Hermitian(W_dir, :U), mat2) for k in 1:(cone.d) @views mul!(mat2[:, k], cone.Wsim_Δ3[:, :, k], V_dir_sim[:, k]) end + @. mat = mat2 + mat2' spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2) - Vvd = smat_to_svec!(zero(w_dir), mat, rt2) + smat_to_svec!(Vvd, mat, rt2) @. mat = W_dir_sim * cone.Δ2_W spectral_outer!(mat, W_vecs, Hermitian(mat, :U), mat2) - Wwd = smat_to_svec!(zero(w_dir), mat, rt2) + smat_to_svec!(Wwd, mat, rt2) @. mat = VW_dir_sim * cone.Δ2_V spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2) - VWwd = smat_to_svec!(zero(w_dir), mat, rt2) + smat_to_svec!(VWwd, mat, rt2) @. mat = V_dir_sim * cone.Δ2_V spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2) - VWvd = smat_to_svec!(zero(w_dir), mat, rt2) + smat_to_svec!(VWvd, mat, rt2) const0 = zi * u_dir + dot(v_dir, dzdV) + dot(w_dir, dzdW) const1 = abs2(const0) + zi * (-dot(v_dir, VWwd) + (-dot(v_dir, Vvd) + dot(w_dir, Wwd)) / 2) - V_part_1a = -const0 * (Vvd + VWwd) - W_part_1a = Wwd - VWvd + Vvd .+= VWwd + V_part_1a = Vvd .*= -const0 + W_part_1a = Wwd .-= VWvd # u dder3[1] = zi * const1 @@ -457,9 +481,11 @@ function dder3( end # v - d3WlogVdV!(d3WlogVdV, Δ3_V, V_λ, V_dir_sim, cone.W_sim, mat) - svec_to_smat!(V_part_1, V_part_1a, rt2) - @. V_dir_sim /= sqrt(V_λ)' + d3WlogVdV = d3WlogVdV!(mat2, Δ3_V, V_λ, V_dir_sim, cone.W_sim, mat) + V_part_1 = svec_to_smat!(cone.mat3, V_part_1a, rt2) + @views tempvec = cone.tempvec[1:(cone.d)] + tempvec .= sqrt.(V_λ) + @. V_dir_sim /= tempvec' ldiv!(Diagonal(V_λ), V_dir_sim) V_part_2 = d3WlogVdV @. V_part_2 += diff_dot_V_VW + diff_dot_V_VW' @@ -471,10 +497,12 @@ function dder3( @. dder3_V += const1 * dzdV # w - svec_to_smat!(W_part_1, W_part_1a, rt2) + W_part_1 = svec_to_smat!(V_part_1, W_part_1a, rt2) spectral_outer!(mat2, V_vecs, Hermitian(diff_dot_V_VV, :U), mat) axpby!(true, mat2, const0, W_part_1) - @. W_dir_sim /= sqrt(W_λ)' + @views tempvec = cone.tempvec[1:(cone.d)] + tempvec .= sqrt.(W_λ) + @. W_dir_sim /= tempvec' ldiv!(Diagonal(W_λ), W_dir_sim) W_part_2 = diff_dot_W_WW mul!(W_part_2, W_dir_sim, W_dir_sim', true, -zi)