Skip to content

Commit

Permalink
remove allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
araujoms committed May 7, 2024
1 parent 423c1ab commit 010061b
Showing 1 changed file with 60 additions and 32 deletions.
92 changes: 60 additions & 32 deletions src/Cones/epitrrelentropytri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ mutable struct EpiTrRelEntropyTri{T <: Real, R <: RealOrComplex{T}} <: Cone{T}
mat3::Matrix{R}
mat4::Matrix{R}
Wsim_Δ3::Array{R, 3}
V_dir_sim::Matrix{R}
W_dir_sim::Matrix{R}
VW_dir_sim::Matrix{R}
diff_dot_V_VV::Matrix{R}
diff_dot_V_VW::Matrix{R}
diff_dot_W_WW::Matrix{R}
Vvd::Vector{T}
Wwd::Vector{T}
VWwd::Vector{T}
VWvd::Vector{T}

function EpiTrRelEntropyTri{T, R}(
dim::Int;
Expand Down Expand Up @@ -130,6 +140,16 @@ function setup_extra_data!(
cone.mat3 = zeros(R, d, d)
cone.mat4 = zeros(R, d, d)
cone.Wsim_Δ3 = zeros(R, d, d, d)
cone.V_dir_sim = zeros(R, d, d)
cone.W_dir_sim = zeros(R, d, d)
cone.VW_dir_sim = zeros(R, d, d)
cone.diff_dot_V_VV = zeros(R, d, d)
cone.diff_dot_V_VW = zeros(R, d, d)
cone.diff_dot_W_WW = zeros(R, d, d)
cone.Vvd = zeros(T, vw_dim)
cone.Wwd = zeros(T, vw_dim)
cone.VWwd = zeros(T, vw_dim)
cone.VWvd = zeros(T, vw_dim)
return
end

Expand Down Expand Up @@ -206,8 +226,11 @@ function update_grad(
mat3 = cone.mat3
g = cone.grad

spectral_outer!(Vi, V_vecs, inv.(V_λ), cone.mat)
spectral_outer!(Wi, W_vecs, inv.(W_λ), cone.mat)
@views tempvec = cone.tempvec[1:(cone.d)]
@. tempvec = inv(V_λ)
spectral_outer!(Vi, V_vecs, tempvec, cone.mat)
@. tempvec = inv(W_λ)
spectral_outer!(Wi, W_vecs, tempvec, cone.mat)

g[1] = -zi

Expand Down Expand Up @@ -371,6 +394,7 @@ function dder3(
dir::AbstractVector{T},
) where {T <: Real, R <: RealOrComplex{T}}
cone.dder3_aux_updated || update_dder3_aux(cone)
dder3 = cone.dder3
rt2 = cone.rt2
V_idxs = cone.V_idxs
W_idxs = cone.W_idxs
Expand All @@ -383,57 +407,57 @@ function dder3(
dzdW = cone.dzdW
mat = cone.mat
mat2 = cone.mat2
dder3 = cone.dder3
V_dir_sim = cone.V_dir_sim
W_dir_sim = cone.W_dir_sim
VW_dir_sim = cone.VW_dir_sim
diff_dot_V_VV = cone.diff_dot_V_VV
diff_dot_V_VW = cone.diff_dot_V_VW
diff_dot_W_WW = cone.diff_dot_W_WW

Vvd = cone.Vvd
Wwd = cone.Wwd
VWwd = cone.VWwd
VWvd = cone.VWvd

u_dir = dir[1]
@views v_dir = dir[V_idxs]
@views w_dir = dir[W_idxs]

# v, w
# TODO prealloc
V_dir = Hermitian(zero(mat), :U)
W_dir = Hermitian(zero(mat), :U)
V_dir_sim = zero(mat)
W_dir_sim = zero(mat)
VW_dir_sim = zero(mat)
W_part_1 = zero(mat)
V_part_1 = zero(mat)
d3WlogVdV = zero(mat)
diff_dot_V_VV = zero(mat)
diff_dot_V_VW = zero(mat)
diff_dot_W_WW = zero(mat)

svec_to_smat!(V_dir.data, v_dir, rt2)
svec_to_smat!(W_dir.data, w_dir, rt2)
spectral_outer!(V_dir_sim, V_vecs', V_dir, mat)
spectral_outer!(W_dir_sim, W_vecs', W_dir, mat)
spectral_outer!(VW_dir_sim, V_vecs', W_dir, mat)
V_dir = svec_to_smat!(mat, v_dir, rt2)
spectral_outer!(V_dir_sim, V_vecs', Hermitian(V_dir, :U), mat2)

W_dir = svec_to_smat!(mat, w_dir, rt2)
spectral_outer!(W_dir_sim, W_vecs', Hermitian(W_dir, :U), mat2)
spectral_outer!(VW_dir_sim, V_vecs', Hermitian(W_dir, :U), mat2)

for k in 1:(cone.d)
@views mul!(mat2[:, k], cone.Wsim_Δ3[:, :, k], V_dir_sim[:, k])
end

@. mat = mat2 + mat2'
spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2)
Vvd = smat_to_svec!(zero(w_dir), mat, rt2)
smat_to_svec!(Vvd, mat, rt2)

@. mat = W_dir_sim * cone.Δ2_W
spectral_outer!(mat, W_vecs, Hermitian(mat, :U), mat2)
Wwd = smat_to_svec!(zero(w_dir), mat, rt2)
smat_to_svec!(Wwd, mat, rt2)

@. mat = VW_dir_sim * cone.Δ2_V
spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2)
VWwd = smat_to_svec!(zero(w_dir), mat, rt2)
smat_to_svec!(VWwd, mat, rt2)

@. mat = V_dir_sim * cone.Δ2_V
spectral_outer!(mat, V_vecs, Hermitian(mat, :U), mat2)
VWvd = smat_to_svec!(zero(w_dir), mat, rt2)
smat_to_svec!(VWvd, mat, rt2)

const0 = zi * u_dir + dot(v_dir, dzdV) + dot(w_dir, dzdW)
const1 =
abs2(const0) + zi * (-dot(v_dir, VWwd) + (-dot(v_dir, Vvd) + dot(w_dir, Wwd)) / 2)

V_part_1a = -const0 * (Vvd + VWwd)
W_part_1a = Wwd - VWvd
Vvd .+= VWwd
V_part_1a = Vvd .*= -const0
W_part_1a = Wwd .-= VWvd

# u
dder3[1] = zi * const1
Expand All @@ -457,9 +481,11 @@ function dder3(
end

# v
d3WlogVdV!(d3WlogVdV, Δ3_V, V_λ, V_dir_sim, cone.W_sim, mat)
svec_to_smat!(V_part_1, V_part_1a, rt2)
@. V_dir_sim /= sqrt(V_λ)'
d3WlogVdV = d3WlogVdV!(mat2, Δ3_V, V_λ, V_dir_sim, cone.W_sim, mat)
V_part_1 = svec_to_smat!(cone.mat3, V_part_1a, rt2)
@views tempvec = cone.tempvec[1:(cone.d)]
tempvec .= sqrt.(V_λ)
@. V_dir_sim /= tempvec'
ldiv!(Diagonal(V_λ), V_dir_sim)
V_part_2 = d3WlogVdV
@. V_part_2 += diff_dot_V_VW + diff_dot_V_VW'
Expand All @@ -471,10 +497,12 @@ function dder3(
@. dder3_V += const1 * dzdV

# w
svec_to_smat!(W_part_1, W_part_1a, rt2)
W_part_1 = svec_to_smat!(V_part_1, W_part_1a, rt2)
spectral_outer!(mat2, V_vecs, Hermitian(diff_dot_V_VV, :U), mat)
axpby!(true, mat2, const0, W_part_1)
@. W_dir_sim /= sqrt(W_λ)'
@views tempvec = cone.tempvec[1:(cone.d)]
tempvec .= sqrt.(W_λ)
@. W_dir_sim /= tempvec'
ldiv!(Diagonal(W_λ), W_dir_sim)
W_part_2 = diff_dot_W_WW
mul!(W_part_2, W_dir_sim, W_dir_sim', true, -zi)
Expand Down

0 comments on commit 010061b

Please sign in to comment.