From 32f45bcf7c09a7a5b61970d1ddd46469d794bef7 Mon Sep 17 00:00:00 2001 From: araujoms Date: Wed, 10 Apr 2024 16:36:04 +0200 Subject: [PATCH] cleanup mul! with Hermitian and Adjoint --- src/Cones/epitrrelentropytri.jl | 45 ++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/src/Cones/epitrrelentropytri.jl b/src/Cones/epitrrelentropytri.jl index 94162e522..f1d930e07 100644 --- a/src/Cones/epitrrelentropytri.jl +++ b/src/Cones/epitrrelentropytri.jl @@ -45,6 +45,8 @@ mutable struct EpiTrRelEntropyTri{T <: Real, R <: RealOrComplex{T}} <: Cone{T} W::Matrix{R} V_fact::Eigen{R} W_fact::Eigen{R} + V_vecs_adj::Matrix{R} + W_vecs_adj::Matrix{R} Vi::Matrix{R} Wi::Matrix{R} W_sim::Matrix{R} @@ -107,6 +109,8 @@ function setup_extra_data!( cone.W_idxs = (vw_dim + 2):(cone.dim) cone.V = zeros(R, d, d) cone.W = zeros(R, d, d) + cone.V_vecs_adj = zeros(R, d, d) + cone.W_vecs_adj = zeros(R, d, d) cone.Vi = zeros(R, d, d) cone.Wi = zeros(R, d, d) cone.W_sim = zeros(R, d, d) @@ -166,7 +170,6 @@ function update_feas( VH = Hermitian(cone.V, :U) WH = Hermitian(cone.W, :U) if isposdef(VH) && isposdef(WH) - # TODO use LAPACK syev! instead of syevr! for efficiency V_fact = cone.V_fact = eigen(VH) W_fact = cone.W_fact = eigen(WH) if isposdef(V_fact) && isposdef(W_fact) @@ -197,6 +200,8 @@ function update_grad( zi = inv(cone.z) (V_λ, V_vecs) = cone.V_fact (W_λ, W_vecs) = cone.W_fact + cone.V_vecs_adj = Matrix(V_vecs') + cone.W_vecs_adj = Matrix(W_vecs') Vi = cone.Vi Wi = cone.Wi W_sim = cone.W_sim @@ -221,7 +226,7 @@ function update_grad( Δ2!(Δ2_V, V_λ, cone.V_λ_log) - spectral_outer!(W_sim, V_vecs', Hermitian(cone.W, :U), mat) + spectral_outer!(W_sim, cone.V_vecs_adj, Hermitian(cone.W, :U), mat) @. mat = W_sim * Δ2_V spectral_outer!(mat2, V_vecs, Hermitian(mat, :U), mat3) @. mat2 *= zi @@ -273,7 +278,7 @@ function update_hess(cone::EpiTrRelEntropyTri) @. H[1, W_idxs] = zi * dzdW # vv - d2zdV2!(d2zdV2, V_vecs, cone.Wsim_Δ3, mat, mat2, mat3, rt2) + d2zdV2!(d2zdV2, V_vecs, cone.V_vecs_adj, cone.Wsim_Δ3, mat, mat2, mat3, rt2) @. d2zdV2 *= -1 @views Hvv = H[V_idxs, V_idxs] @@ -313,6 +318,8 @@ function hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiTrRe arr_W_mat = cone.mat4 (V_λ, V_vecs) = cone.V_fact (W_λ, W_vecs) = cone.W_fact + V_vecs_adj = cone.V_vecs_adj + W_vecs_adj = cone.W_vecs_adj @inbounds for i in 1:size(arr, 2) @views V_arr = arr[V_idxs, i] @@ -324,14 +331,14 @@ function hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiTrRe @views @. prod[W_idxs, i] = dzdW * const1 # Hwv * a_v arr_V_mat = svec_to_smat!(temp, V_arr, rt2) - spectral_outer!(Varr_simV, V_vecs', Hermitian(arr_V_mat, :U), temp2) + spectral_outer!(Varr_simV, V_vecs_adj, Hermitian(arr_V_mat, :U), temp2) @. temp = Varr_simV * cone.Δ2_V spectral_outer!(temp, V_vecs, Hermitian(temp, :U), temp2) @. temp /= -z @views prod[W_idxs, i] .+= smat_to_svec!(tempvec, temp, rt2) # Hww * a_w svec_to_smat!(arr_W_mat, W_arr, rt2) - Warr_simW = spectral_outer!(temp2, W_vecs', Hermitian(arr_W_mat, :U), temp) + Warr_simW = spectral_outer!(temp2, W_vecs_adj, Hermitian(arr_W_mat, :U), temp) @. temp = Warr_simW * cone.Δ2_W / z @. Warr_simW /= W_λ' ldiv!(Diagonal(W_λ), Warr_simW) @@ -345,8 +352,8 @@ function hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiTrRe @views mul!(temp2[:, k], cone.Wsim_Δ3[:, :, k], Varr_simV[:, k]) end @. temp = temp2 + temp2' - # destroys arr_W_mat - Warr_simV = spectral_outer!(arr_W_mat, V_vecs', Hermitian(arr_W_mat, :U), temp2) + # overwrites arr_W_mat + Warr_simV = spectral_outer!(arr_W_mat, V_vecs_adj, Hermitian(arr_W_mat, :U), temp2) @. temp += Warr_simV * cone.Δ2_V @. temp /= -z @. Varr_simV /= V_λ' @@ -385,6 +392,8 @@ function dder3( mat = cone.mat mat2 = cone.mat2 dder3 = cone.dder3 + V_vecs_adj = cone.V_vecs_adj + W_vecs_adj = cone.W_vecs_adj u_dir = dir[1] @views v_dir = dir[V_idxs] @@ -406,9 +415,9 @@ function dder3( svec_to_smat!(V_dir.data, v_dir, rt2) svec_to_smat!(W_dir.data, w_dir, rt2) - spectral_outer!(V_dir_sim, V_vecs', V_dir, mat) - spectral_outer!(W_dir_sim, W_vecs', W_dir, mat) - spectral_outer!(VW_dir_sim, V_vecs', W_dir, mat) + spectral_outer!(V_dir_sim, V_vecs_adj, V_dir, mat) + spectral_outer!(W_dir_sim, W_vecs_adj, W_dir, mat) + spectral_outer!(VW_dir_sim, V_vecs_adj, W_dir, mat) for k in 1:(cone.d) @views mul!(mat2[:, k], cone.Wsim_Δ3[:, :, k], V_dir_sim[:, k]) @@ -465,7 +474,7 @@ function dder3( V_part_2 = d3WlogVdV @. V_part_2 += diff_dot_V_VW + diff_dot_V_VW' mul!(V_part_2, V_dir_sim, V_dir_sim', true, zi) - mul!(mat, Hermitian(V_part_2, :U), V_vecs') + mul!(mat, Hermitian(V_part_2, :U), V_vecs_adj) mul!(V_part_1, V_vecs, mat, true, zi) @views dder3_V = dder3[V_idxs] smat_to_svec!(dder3_V, V_part_1, rt2) @@ -479,7 +488,7 @@ function dder3( ldiv!(Diagonal(W_λ), W_dir_sim) W_part_2 = diff_dot_W_WW mul!(W_part_2, W_dir_sim, W_dir_sim', true, -zi) - mul!(mat, Hermitian(W_part_2, :U), W_vecs') + mul!(mat, Hermitian(W_part_2, :U), W_vecs_adj) mul!(W_part_1, W_vecs, mat, true, zi) @views dder3_W = dder3[W_idxs] smat_to_svec!(dder3_W, W_part_1, rt2) @@ -544,6 +553,7 @@ end function d2zdV2!( d2zdV2::Matrix{T}, vecs::Matrix{R}, + vecs_adj::Matrix{R}, Wsim_Δ3::Array{R, 3}, mat::Matrix{R}, # temp mat2::Matrix{R}, # temp @@ -551,8 +561,7 @@ function d2zdV2!( rt2::T, ) where {T <: Real, R <: RealOrComplex{T}} d = size(vecs, 1) - V = copyto!(mat, vecs') - V_views = [view(V, :, i) for i in 1:d] + V_views = [view(vecs_adj, :, i) for i in 1:d] rt2i = inv(rt2) scals = (R <: Complex{T} ? [rt2i, rt2i * im] : [rt2i]) @@ -566,8 +575,8 @@ function d2zdV2!( end # mat2 = vecs * (mat3 + mat3) * vecs' @. mat2 = mat3 + mat3' - mul!(mat3, Hermitian(mat2, :U), V) - mul!(mat2, V', mat3) + mul!(mat3, Hermitian(mat2, :U), vecs_adj) + mul!(mat2, vecs, mat3) @views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2) col_idx += 1 end @@ -577,8 +586,8 @@ function d2zdV2!( @views mul!(mat3[:, k], Wsim_Δ3[:, :, k], mat2[:, k]) end @. mat2 = mat3 + mat3' - mul!(mat3, Hermitian(mat2, :U), V) - mul!(mat2, V', mat3) + mul!(mat3, Hermitian(mat2, :U), vecs_adj) + mul!(mat2, vecs, mat3) @views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2) col_idx += 1 end