Skip to content

Commit

Permalink
use safe spectral_outer! instead of copying
Browse files Browse the repository at this point in the history
  • Loading branch information
araujoms committed Apr 11, 2024
1 parent be17235 commit fe312b1
Showing 1 changed file with 16 additions and 28 deletions.
44 changes: 16 additions & 28 deletions src/Cones/epitrrelentropytri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ mutable struct EpiTrRelEntropyTri{T <: Real, R <: RealOrComplex{T}} <: Cone{T}
W::Matrix{R}
V_fact::Eigen{R}
W_fact::Eigen{R}
V_vecs_adj::Matrix{R}
W_vecs_adj::Matrix{R}
Vi::Matrix{R}
Wi::Matrix{R}
W_sim::Matrix{R}
Expand Down Expand Up @@ -109,8 +107,6 @@ function setup_extra_data!(
cone.W_idxs = (vw_dim + 2):(cone.dim)
cone.V = zeros(R, d, d)
cone.W = zeros(R, d, d)
cone.V_vecs_adj = zeros(R, d, d)
cone.W_vecs_adj = zeros(R, d, d)
cone.Vi = zeros(R, d, d)
cone.Wi = zeros(R, d, d)
cone.W_sim = zeros(R, d, d)
Expand Down Expand Up @@ -200,8 +196,6 @@ function update_grad(
zi = inv(cone.z)
(V_λ, V_vecs) = cone.V_fact
(W_λ, W_vecs) = cone.W_fact
cone.V_vecs_adj = Matrix(V_vecs')
cone.W_vecs_adj = Matrix(W_vecs')
Vi = cone.Vi
Wi = cone.Wi
W_sim = cone.W_sim
Expand All @@ -226,7 +220,7 @@ function update_grad(

Δ2!(Δ2_V, V_λ, cone.V_λ_log)

spectral_outer!(W_sim, cone.V_vecs_adj, Hermitian(cone.W, :U), mat)
spectral_outer!(W_sim, V_vecs', Hermitian(cone.W, :U), mat)
@. mat = W_sim * Δ2_V
spectral_outer!(mat2, V_vecs, Hermitian(mat, :U), mat3)
@. mat2 *= zi
Expand Down Expand Up @@ -278,7 +272,7 @@ function update_hess(cone::EpiTrRelEntropyTri)
@. H[1, W_idxs] = zi * dzdW

# vv
d2zdV2!(d2zdV2, V_vecs, cone.V_vecs_adj, cone.Wsim_Δ3, mat, mat2, mat3, rt2)
d2zdV2!(d2zdV2, V_vecs, cone.Wsim_Δ3, mat, mat2, mat3, rt2)

@. d2zdV2 *= -1
@views Hvv = H[V_idxs, V_idxs]
Expand Down Expand Up @@ -318,8 +312,6 @@ function hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiTrRe
arr_W_mat = cone.mat4
(V_λ, V_vecs) = cone.V_fact
(W_λ, W_vecs) = cone.W_fact
V_vecs_adj = cone.V_vecs_adj
W_vecs_adj = cone.W_vecs_adj

@inbounds for i in 1:size(arr, 2)
@views V_arr = arr[V_idxs, i]
Expand All @@ -331,14 +323,14 @@ function hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiTrRe
@views @. prod[W_idxs, i] = dzdW * const1
# Hwv * a_v
arr_V_mat = svec_to_smat!(temp, V_arr, rt2)
spectral_outer!(Varr_simV, V_vecs_adj, Hermitian(arr_V_mat, :U), temp2)
spectral_outer!(Varr_simV, V_vecs', Hermitian(arr_V_mat, :U), temp2)
@. temp = Varr_simV * cone.Δ2_V
spectral_outer!(temp, V_vecs, Hermitian(temp, :U), temp2)
@. temp /= -z
@views prod[W_idxs, i] .+= smat_to_svec!(tempvec, temp, rt2)
# Hww * a_w
svec_to_smat!(arr_W_mat, W_arr, rt2)
Warr_simW = spectral_outer!(temp2, W_vecs_adj, Hermitian(arr_W_mat, :U), temp)
Warr_simW = spectral_outer!(temp2, W_vecs', Hermitian(arr_W_mat, :U), temp)
@. temp = Warr_simW * cone.Δ2_W / z
@. Warr_simW /= W_λ'
ldiv!(Diagonal(W_λ), Warr_simW)
Expand All @@ -353,7 +345,7 @@ function hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiTrRe
end
@. temp = temp2 + temp2'
# overwrites arr_W_mat
Warr_simV = spectral_outer!(arr_W_mat, V_vecs_adj, Hermitian(arr_W_mat, :U), temp2)
Warr_simV = spectral_outer!(arr_W_mat, V_vecs', Hermitian(arr_W_mat, :U), temp2)
@. temp += Warr_simV * cone.Δ2_V
@. temp /= -z
@. Varr_simV /= V_λ'
Expand Down Expand Up @@ -392,8 +384,6 @@ function dder3(
mat = cone.mat
mat2 = cone.mat2
dder3 = cone.dder3
V_vecs_adj = cone.V_vecs_adj
W_vecs_adj = cone.W_vecs_adj

u_dir = dir[1]
@views v_dir = dir[V_idxs]
Expand All @@ -415,9 +405,9 @@ function dder3(

svec_to_smat!(V_dir.data, v_dir, rt2)
svec_to_smat!(W_dir.data, w_dir, rt2)
spectral_outer!(V_dir_sim, V_vecs_adj, V_dir, mat)
spectral_outer!(W_dir_sim, W_vecs_adj, W_dir, mat)
spectral_outer!(VW_dir_sim, V_vecs_adj, W_dir, mat)
spectral_outer!(V_dir_sim, V_vecs', V_dir, mat)
spectral_outer!(W_dir_sim, W_vecs', W_dir, mat)
spectral_outer!(VW_dir_sim, V_vecs', W_dir, mat)

for k in 1:(cone.d)
@views mul!(mat2[:, k], cone.Wsim_Δ3[:, :, k], V_dir_sim[:, k])
Expand Down Expand Up @@ -474,8 +464,8 @@ function dder3(
V_part_2 = d3WlogVdV
@. V_part_2 += diff_dot_V_VW + diff_dot_V_VW'
mul!(V_part_2, V_dir_sim, V_dir_sim', true, zi)
mul!(mat, Hermitian(V_part_2, :U), V_vecs_adj)
mul!(V_part_1, V_vecs, mat, true, zi)
mul!(mat, V_vecs, Hermitian(V_part_2, :U))
mul!(V_part_1, mat, V_vecs', true, zi)
@views dder3_V = dder3[V_idxs]
smat_to_svec!(dder3_V, V_part_1, rt2)
@. dder3_V += const1 * dzdV
Expand All @@ -488,8 +478,8 @@ function dder3(
ldiv!(Diagonal(W_λ), W_dir_sim)
W_part_2 = diff_dot_W_WW
mul!(W_part_2, W_dir_sim, W_dir_sim', true, -zi)
mul!(mat, Hermitian(W_part_2, :U), W_vecs_adj)
mul!(W_part_1, W_vecs, mat, true, zi)
mul!(mat, W_vecs, Hermitian(W_part_2, :U))
mul!(W_part_1, mat, W_vecs', true, zi)
@views dder3_W = dder3[W_idxs]
smat_to_svec!(dder3_W, W_part_1, rt2)
@. dder3_W += const1 * dzdW
Expand Down Expand Up @@ -553,15 +543,15 @@ end
function d2zdV2!(
d2zdV2::Matrix{T},
vecs::Matrix{R},
vecs_adj::Matrix{R},
Wsim_Δ3::Array{R, 3},
mat::Matrix{R}, # temp
mat2::Matrix{R}, # temp
mat3::Matrix{R}, # temp
rt2::T,
) where {T <: Real, R <: RealOrComplex{T}}
d = size(vecs, 1)
V_views = [view(vecs_adj, :, i) for i in 1:d]
V = copyto!(mat,vecs')
V_views = [view(V, :, i) for i in 1:d]
rt2i = inv(rt2)
scals = (R <: Complex{T} ? [rt2i, rt2i * im] : [rt2i])

Expand All @@ -575,8 +565,7 @@ function d2zdV2!(
end
# mat2 = vecs * (mat3 + mat3) * vecs'
@. mat2 = mat3 + mat3'
mul!(mat3, Hermitian(mat2, :U), vecs_adj)
mul!(mat2, vecs, mat3)
spectral_outer!(mat2, V', Hermitian(mat2, :U), mat3)
@views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2)
col_idx += 1
end
Expand All @@ -586,8 +575,7 @@ function d2zdV2!(
@views mul!(mat3[:, k], Wsim_Δ3[:, :, k], mat2[:, k])
end
@. mat2 = mat3 + mat3'
mul!(mat3, Hermitian(mat2, :U), vecs_adj)
mul!(mat2, vecs, mat3)
spectral_outer!(mat2, V', Hermitian(mat2, :U), mat3)
@views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2)
col_idx += 1
end
Expand Down

0 comments on commit fe312b1

Please sign in to comment.