Skip to content

Commit

Permalink
Merge branch 'jump-dev:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
araujoms authored May 9, 2024
2 parents 423c1ab + 2f4008f commit 6a4cfea
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 42 deletions.
11 changes: 6 additions & 5 deletions src/Cones/arrayutilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,24 +382,25 @@ function eig_dot_kron!(
@assert issymmetric(inner) # must be symmetric (wrapper is less efficient)
rt2i = inv(rt2)
copyto!(V, vecs') # allows fast column slices
V_views = [view(V, :, i) for i in 1:size(inner, 1)]
scals = (R <: Complex{T} ? [rt2i, rt2i * im] : [rt2i]) # real and imag parts

col_idx = 1
@inbounds for (j, V_j) in enumerate(V_views)
@inbounds for j in 1:size(inner, 1)
@views V_j = V[:, j]
for i in 1:(j - 1), scal in scals
mul!(temp1, V_j, V_views[i]', scal, false)
@views V_i = V[:, i]
mul!(temp1, V_j, V_i', scal, false)
@. temp2 = inner * (temp1 + temp1')
mul!(temp1, Hermitian(temp2, :U), V)
mul!(temp2, V', temp1)
mul!(temp2, vecs, temp1)
@views smat_to_svec!(skr[:, col_idx], temp2, rt2)
col_idx += 1
end

mul!(temp2, V_j, V_j')
temp2 .*= inner
mul!(temp1, Hermitian(temp2, :U), V)
mul!(temp2, V', temp1)
mul!(temp2, vecs, temp1)
@views smat_to_svec!(skr[:, col_idx], temp2, rt2)
col_idx += 1
end
Expand Down
102 changes: 65 additions & 37 deletions src/Cones/epitrrelentropytri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ mutable struct EpiTrRelEntropyTri{T <: Real, R <: RealOrComplex{T}} <: Cone{T}
mat3::Matrix{R}
mat4::Matrix{R}
Wsim_Δ3::Array{R, 3}
V_dir_sim::Matrix{R}
W_dir_sim::Matrix{R}
VW_dir_sim::Matrix{R}
diff_dot_V_VV::Matrix{R}
diff_dot_V_VW::Matrix{R}
diff_dot_W_WW::Matrix{R}
Vvd::Vector{T}
Wwd::Vector{T}
VWwd::Vector{T}
VWvd::Vector{T}

function EpiTrRelEntropyTri{T, R}(
dim::Int;
Expand Down Expand Up @@ -130,6 +140,16 @@ function setup_extra_data!(
cone.mat3 = zeros(R, d, d)
cone.mat4 = zeros(R, d, d)
cone.Wsim_Δ3 = zeros(R, d, d, d)
cone.V_dir_sim = zeros(R, d, d)
cone.W_dir_sim = zeros(R, d, d)
cone.VW_dir_sim = zeros(R, d, d)
cone.diff_dot_V_VV = zeros(R, d, d)
cone.diff_dot_V_VW = zeros(R, d, d)
cone.diff_dot_W_WW = zeros(R, d, d)
cone.Vvd = zeros(T, vw_dim)
cone.Wwd = zeros(T, vw_dim)
cone.VWwd = zeros(T, vw_dim)
cone.VWvd = zeros(T, vw_dim)
return
end

Expand Down Expand Up @@ -206,8 +226,11 @@ function update_grad(
mat3 = cone.mat3
g = cone.grad

spectral_outer!(Vi, V_vecs, inv.(V_λ), cone.mat)
spectral_outer!(Wi, W_vecs, inv.(W_λ), cone.mat)
@views tempvec = cone.tempvec[1:(cone.d)]
@. tempvec = inv(V_λ)
spectral_outer!(Vi, V_vecs, tempvec, cone.mat)
@. tempvec = inv(W_λ)
spectral_outer!(Wi, W_vecs, tempvec, cone.mat)

g[1] = -zi

Expand Down Expand Up @@ -371,6 +394,7 @@ function dder3(
dir::AbstractVector{T},
) where {T <: Real, R <: RealOrComplex{T}}
cone.dder3_aux_updated || update_dder3_aux(cone)
dder3 = cone.dder3
rt2 = cone.rt2
V_idxs = cone.V_idxs
W_idxs = cone.W_idxs
Expand All @@ -383,57 +407,56 @@ function dder3(
dzdW = cone.dzdW
mat = cone.mat
mat2 = cone.mat2
dder3 = cone.dder3
V_dir_sim = cone.V_dir_sim
W_dir_sim = cone.W_dir_sim
VW_dir_sim = cone.VW_dir_sim
diff_dot_V_VV = cone.diff_dot_V_VV
diff_dot_V_VW = cone.diff_dot_V_VW
diff_dot_W_WW = cone.diff_dot_W_WW
Vvd = cone.Vvd
Wwd = cone.Wwd
VWwd = cone.VWwd
VWvd = cone.VWvd

u_dir = dir[1]
@views v_dir = dir[V_idxs]
@views w_dir = dir[W_idxs]

# v, w
# TODO prealloc
V_dir = Hermitian(zero(mat), :U)
W_dir = Hermitian(zero(mat), :U)
V_dir_sim = zero(mat)
W_dir_sim = zero(mat)
VW_dir_sim = zero(mat)
W_part_1 = zero(mat)
V_part_1 = zero(mat)
d3WlogVdV = zero(mat)
diff_dot_V_VV = zero(mat)
diff_dot_V_VW = zero(mat)
diff_dot_W_WW = zero(mat)

svec_to_smat!(V_dir.data, v_dir, rt2)
svec_to_smat!(W_dir.data, w_dir, rt2)
spectral_outer!(V_dir_sim, V_vecs', V_dir, mat)
spectral_outer!(W_dir_sim, W_vecs', W_dir, mat)
spectral_outer!(VW_dir_sim, V_vecs', W_dir, mat)
V_dir = svec_to_smat!(mat, v_dir, rt2)
spectral_outer!(V_dir_sim, V_vecs', Hermitian(V_dir, :U), mat2)

W_dir = svec_to_smat!(mat, w_dir, rt2)
spectral_outer!(W_dir_sim, W_vecs', Hermitian(W_dir, :U), mat2)
spectral_outer!(VW_dir_sim, V_vecs', Hermitian(W_dir, :U), mat2)

for k in 1:(cone.d)
@views mul!(mat2[:, k], cone.Wsim_Δ3[:, :, k], V_dir_sim[:, k])
end

@. mat = mat2 + mat2'
spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2)
Vvd = smat_to_svec!(zero(w_dir), mat, rt2)
smat_to_svec!(Vvd, mat, rt2)

@. mat = W_dir_sim * cone.Δ2_W
spectral_outer!(mat, W_vecs, Hermitian(mat, :U), mat2)
Wwd = smat_to_svec!(zero(w_dir), mat, rt2)
smat_to_svec!(Wwd, mat, rt2)

@. mat = VW_dir_sim * cone.Δ2_V
spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2)
VWwd = smat_to_svec!(zero(w_dir), mat, rt2)
smat_to_svec!(VWwd, mat, rt2)

@. mat = V_dir_sim * cone.Δ2_V
spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2)
VWvd = smat_to_svec!(zero(w_dir), mat, rt2)
smat_to_svec!(VWvd, mat, rt2)

const0 = zi * u_dir + dot(v_dir, dzdV) + dot(w_dir, dzdW)
const1 =
abs2(const0) + zi * (-dot(v_dir, VWwd) + (-dot(v_dir, Vvd) + dot(w_dir, Wwd)) / 2)

V_part_1a = -const0 * (Vvd + VWwd)
W_part_1a = Wwd - VWvd
Vvd .+= VWwd
V_part_1a = Vvd .*= -const0
W_part_1a = Wwd .-= VWvd

# u
dder3[1] = zi * const1
Expand All @@ -457,9 +480,11 @@ function dder3(
end

# v
d3WlogVdV!(d3WlogVdV, Δ3_V, V_λ, V_dir_sim, cone.W_sim, mat)
svec_to_smat!(V_part_1, V_part_1a, rt2)
@. V_dir_sim /= sqrt(V_λ)'
d3WlogVdV = d3WlogVdV!(mat2, Δ3_V, V_λ, V_dir_sim, cone.W_sim, mat)
V_part_1 = svec_to_smat!(cone.mat3, V_part_1a, rt2)
@views tempvec = cone.tempvec[1:(cone.d)]
tempvec .= sqrt.(V_λ)
@. V_dir_sim /= tempvec'
ldiv!(Diagonal(V_λ), V_dir_sim)
V_part_2 = d3WlogVdV
@. V_part_2 += diff_dot_V_VW + diff_dot_V_VW'
Expand All @@ -471,10 +496,12 @@ function dder3(
@. dder3_V += const1 * dzdV

# w
svec_to_smat!(W_part_1, W_part_1a, rt2)
W_part_1 = svec_to_smat!(V_part_1, W_part_1a, rt2)
spectral_outer!(mat2, V_vecs, Hermitian(diff_dot_V_VV, :U), mat)
axpby!(true, mat2, const0, W_part_1)
@. W_dir_sim /= sqrt(W_λ)'
@views tempvec = cone.tempvec[1:(cone.d)]
tempvec .= sqrt.(W_λ)
@. W_dir_sim /= tempvec'
ldiv!(Diagonal(W_λ), W_dir_sim)
W_part_2 = diff_dot_W_WW
mul!(W_part_2, W_dir_sim, W_dir_sim', true, -zi)
Expand Down Expand Up @@ -551,31 +578,32 @@ function d2zdV2!(
) where {T <: Real, R <: RealOrComplex{T}}
d = size(vecs, 1)
V = copyto!(mat, vecs')
V_views = [view(V, :, i) for i in 1:d]
rt2i = inv(rt2)
scals = (R <: Complex{T} ? [rt2i, rt2i * im] : [rt2i])

col_idx = 1
@inbounds for j in 1:d
@views V_j = V[:, j]
for i in 1:(j - 1), scal in scals
mul!(mat3, V_views[j], V_views[i]', scal, false)
@views V_i = V[:, i]
mul!(mat3, V_j, V_i', scal, false)
@. mat2 = mat3 + mat3'
for k in 1:d
@views mul!(mat3[:, k], Wsim_Δ3[:, :, k], mat2[:, k])
end
# mat2 = vecs * (mat3 + mat3) * vecs'
@. mat2 = mat3 + mat3'
spectral_outer!(mat2, V', Hermitian(mat2, :U), mat3)
spectral_outer!(mat2, vecs, Hermitian(mat2, :U), mat3)
@views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2)
col_idx += 1
end

mul!(mat2, V_views[j], V_views[j]')
mul!(mat2, V_j, V_j')
for k in 1:d
@views mul!(mat3[:, k], Wsim_Δ3[:, :, k], mat2[:, k])
end
@. mat2 = mat3 + mat3'
spectral_outer!(mat2, V', Hermitian(mat2, :U), mat3)
spectral_outer!(mat2, vecs, Hermitian(mat2, :U), mat3)
@views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2)
col_idx += 1
end
Expand Down

0 comments on commit 6a4cfea

Please sign in to comment.