Skip to content

Commit

Permalink
cleanup mul! with Hermitian and Adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
araujoms committed Apr 10, 2024
1 parent 4659a7e commit 32f45bc
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions src/Cones/epitrrelentropytri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ mutable struct EpiTrRelEntropyTri{T <: Real, R <: RealOrComplex{T}} <: Cone{T}
W::Matrix{R}
V_fact::Eigen{R}
W_fact::Eigen{R}
V_vecs_adj::Matrix{R}
W_vecs_adj::Matrix{R}
Vi::Matrix{R}
Wi::Matrix{R}
W_sim::Matrix{R}
Expand Down Expand Up @@ -107,6 +109,8 @@ function setup_extra_data!(
cone.W_idxs = (vw_dim + 2):(cone.dim)
cone.V = zeros(R, d, d)
cone.W = zeros(R, d, d)
cone.V_vecs_adj = zeros(R, d, d)
cone.W_vecs_adj = zeros(R, d, d)
cone.Vi = zeros(R, d, d)
cone.Wi = zeros(R, d, d)
cone.W_sim = zeros(R, d, d)
Expand Down Expand Up @@ -166,7 +170,6 @@ function update_feas(
VH = Hermitian(cone.V, :U)
WH = Hermitian(cone.W, :U)
if isposdef(VH) && isposdef(WH)
# TODO use LAPACK syev! instead of syevr! for efficiency
V_fact = cone.V_fact = eigen(VH)
W_fact = cone.W_fact = eigen(WH)
if isposdef(V_fact) && isposdef(W_fact)
Expand Down Expand Up @@ -197,6 +200,8 @@ function update_grad(
zi = inv(cone.z)
(V_λ, V_vecs) = cone.V_fact
(W_λ, W_vecs) = cone.W_fact
cone.V_vecs_adj = Matrix(V_vecs')
cone.W_vecs_adj = Matrix(W_vecs')
Vi = cone.Vi
Wi = cone.Wi
W_sim = cone.W_sim
Expand All @@ -221,7 +226,7 @@ function update_grad(

Δ2!(Δ2_V, V_λ, cone.V_λ_log)

spectral_outer!(W_sim, V_vecs', Hermitian(cone.W, :U), mat)
spectral_outer!(W_sim, cone.V_vecs_adj, Hermitian(cone.W, :U), mat)
@. mat = W_sim * Δ2_V
spectral_outer!(mat2, V_vecs, Hermitian(mat, :U), mat3)
@. mat2 *= zi
Expand Down Expand Up @@ -273,7 +278,7 @@ function update_hess(cone::EpiTrRelEntropyTri)
@. H[1, W_idxs] = zi * dzdW

# vv
d2zdV2!(d2zdV2, V_vecs, cone.Wsim_Δ3, mat, mat2, mat3, rt2)
d2zdV2!(d2zdV2, V_vecs, cone.V_vecs_adj, cone.Wsim_Δ3, mat, mat2, mat3, rt2)

@. d2zdV2 *= -1
@views Hvv = H[V_idxs, V_idxs]
Expand Down Expand Up @@ -313,6 +318,8 @@ function hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiTrRe
arr_W_mat = cone.mat4
(V_λ, V_vecs) = cone.V_fact
(W_λ, W_vecs) = cone.W_fact
V_vecs_adj = cone.V_vecs_adj
W_vecs_adj = cone.W_vecs_adj

@inbounds for i in 1:size(arr, 2)
@views V_arr = arr[V_idxs, i]
Expand All @@ -324,14 +331,14 @@ function hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiTrRe
@views @. prod[W_idxs, i] = dzdW * const1
# Hwv * a_v
arr_V_mat = svec_to_smat!(temp, V_arr, rt2)
spectral_outer!(Varr_simV, V_vecs', Hermitian(arr_V_mat, :U), temp2)
spectral_outer!(Varr_simV, V_vecs_adj, Hermitian(arr_V_mat, :U), temp2)
@. temp = Varr_simV * cone.Δ2_V
spectral_outer!(temp, V_vecs, Hermitian(temp, :U), temp2)
@. temp /= -z
@views prod[W_idxs, i] .+= smat_to_svec!(tempvec, temp, rt2)
# Hww * a_w
svec_to_smat!(arr_W_mat, W_arr, rt2)
Warr_simW = spectral_outer!(temp2, W_vecs', Hermitian(arr_W_mat, :U), temp)
Warr_simW = spectral_outer!(temp2, W_vecs_adj, Hermitian(arr_W_mat, :U), temp)
@. temp = Warr_simW * cone.Δ2_W / z
@. Warr_simW /= W_λ'
ldiv!(Diagonal(W_λ), Warr_simW)
Expand All @@ -345,8 +352,8 @@ function hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiTrRe
@views mul!(temp2[:, k], cone.Wsim_Δ3[:, :, k], Varr_simV[:, k])
end
@. temp = temp2 + temp2'
# destroys arr_W_mat
Warr_simV = spectral_outer!(arr_W_mat, V_vecs', Hermitian(arr_W_mat, :U), temp2)
# overwrites arr_W_mat
Warr_simV = spectral_outer!(arr_W_mat, V_vecs_adj, Hermitian(arr_W_mat, :U), temp2)
@. temp += Warr_simV * cone.Δ2_V
@. temp /= -z
@. Varr_simV /= V_λ'
Expand Down Expand Up @@ -385,6 +392,8 @@ function dder3(
mat = cone.mat
mat2 = cone.mat2
dder3 = cone.dder3
V_vecs_adj = cone.V_vecs_adj
W_vecs_adj = cone.W_vecs_adj

u_dir = dir[1]
@views v_dir = dir[V_idxs]
Expand All @@ -406,9 +415,9 @@ function dder3(

svec_to_smat!(V_dir.data, v_dir, rt2)
svec_to_smat!(W_dir.data, w_dir, rt2)
spectral_outer!(V_dir_sim, V_vecs', V_dir, mat)
spectral_outer!(W_dir_sim, W_vecs', W_dir, mat)
spectral_outer!(VW_dir_sim, V_vecs', W_dir, mat)
spectral_outer!(V_dir_sim, V_vecs_adj, V_dir, mat)
spectral_outer!(W_dir_sim, W_vecs_adj, W_dir, mat)
spectral_outer!(VW_dir_sim, V_vecs_adj, W_dir, mat)

for k in 1:(cone.d)
@views mul!(mat2[:, k], cone.Wsim_Δ3[:, :, k], V_dir_sim[:, k])
Expand Down Expand Up @@ -465,7 +474,7 @@ function dder3(
V_part_2 = d3WlogVdV
@. V_part_2 += diff_dot_V_VW + diff_dot_V_VW'
mul!(V_part_2, V_dir_sim, V_dir_sim', true, zi)
mul!(mat, Hermitian(V_part_2, :U), V_vecs')
mul!(mat, Hermitian(V_part_2, :U), V_vecs_adj)
mul!(V_part_1, V_vecs, mat, true, zi)
@views dder3_V = dder3[V_idxs]
smat_to_svec!(dder3_V, V_part_1, rt2)
Expand All @@ -479,7 +488,7 @@ function dder3(
ldiv!(Diagonal(W_λ), W_dir_sim)
W_part_2 = diff_dot_W_WW
mul!(W_part_2, W_dir_sim, W_dir_sim', true, -zi)
mul!(mat, Hermitian(W_part_2, :U), W_vecs')
mul!(mat, Hermitian(W_part_2, :U), W_vecs_adj)
mul!(W_part_1, W_vecs, mat, true, zi)
@views dder3_W = dder3[W_idxs]
smat_to_svec!(dder3_W, W_part_1, rt2)
Expand Down Expand Up @@ -544,15 +553,15 @@ end
function d2zdV2!(
d2zdV2::Matrix{T},
vecs::Matrix{R},
vecs_adj::Matrix{R},
Wsim_Δ3::Array{R, 3},
mat::Matrix{R}, # temp
mat2::Matrix{R}, # temp
mat3::Matrix{R}, # temp
rt2::T,
) where {T <: Real, R <: RealOrComplex{T}}
d = size(vecs, 1)
V = copyto!(mat, vecs')
V_views = [view(V, :, i) for i in 1:d]
V_views = [view(vecs_adj, :, i) for i in 1:d]
rt2i = inv(rt2)
scals = (R <: Complex{T} ? [rt2i, rt2i * im] : [rt2i])

Expand All @@ -566,8 +575,8 @@ function d2zdV2!(
end
# mat2 = vecs * (mat3 + mat3) * vecs'
@. mat2 = mat3 + mat3'
mul!(mat3, Hermitian(mat2, :U), V)
mul!(mat2, V', mat3)
mul!(mat3, Hermitian(mat2, :U), vecs_adj)
mul!(mat2, vecs, mat3)
@views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2)
col_idx += 1
end
Expand All @@ -577,8 +586,8 @@ function d2zdV2!(
@views mul!(mat3[:, k], Wsim_Δ3[:, :, k], mat2[:, k])
end
@. mat2 = mat3 + mat3'
mul!(mat3, Hermitian(mat2, :U), V)
mul!(mat2, V', mat3)
mul!(mat3, Hermitian(mat2, :U), vecs_adj)
mul!(mat2, vecs, mat3)
@views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2)
col_idx += 1
end
Expand Down

0 comments on commit 32f45bc

Please sign in to comment.