Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Hessian sqrt oracles #645

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions src/Cones/hypogeomean.jl
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@ mutable struct HypoGeoMean{T <: Real} <: Cone{T}
feas_updated::Bool
grad_updated::Bool
hess_updated::Bool
hess_sqrt_aux_updated::Bool
inv_hess_updated::Bool
hess_fact_updated::Bool
is_feas::Bool
@@ -32,6 +33,10 @@ mutable struct HypoGeoMean{T <: Real} <: Cone{T}
wgeo::T
z::T
tempw::Vector{T}
wgeozw::Vector{T}
Hww_sqrt::Matrix{T}
temp1::Vector{T}
temp2::Vector{T}

function HypoGeoMean{T}(
dim::Int;
@@ -49,6 +54,8 @@ end

# HypoGeoMean does not use the heuristic neighborhood.
function use_heuristic_neighborhood(cone::HypoGeoMean)
    return false
end

# Mark every cached oracle of the cone as stale; evaluates to `false`.
function reset_data(cone::HypoGeoMean)
    cone.hess_sqrt_aux_updated = false
    cone.hess_fact_updated = false
    cone.inv_hess_updated = false
    cone.hess_updated = false
    cone.grad_updated = false
    cone.feas_updated = false
    return false
end

function setup_extra_data(cone::HypoGeoMean{T}) where {T <: Real}
dim = cone.dim
cone.hess = Symmetric(zeros(T, dim, dim), :U)
@@ -57,6 +64,10 @@ function setup_extra_data(cone::HypoGeoMean{T}) where {T <: Real}
wdim = dim - 1
cone.tempw = zeros(T, wdim)
cone.iwdim = inv(T(wdim))
cone.wgeozw = zeros(T, wdim)
cone.Hww_sqrt = zeros(T, wdim, wdim)
cone.temp1 = zeros(T, wdim)
cone.temp2 = zeros(T, wdim)
return cone
end

@@ -111,8 +122,18 @@ function update_grad(cone::HypoGeoMean)
return cone.grad
end

# Fill cone.wgeozw with -iwdim * wgeo / w[j] / z for every entry of the
# w part of the point (entries 2:end), then flag the auxiliary data as current.
function update_hess_sqrt_aux(cone::HypoGeoMean)
    w_part = view(cone.point, 2:length(cone.point))
    out = cone.wgeozw
    zval = cone.z
    # same left-to-right evaluation as the fused broadcast: ((-iwdim * wgeo) / w) / z
    numer = -cone.iwdim * cone.wgeo
    for j in eachindex(out, w_part)
        out[j] = numer / w_part[j] / zval
    end
    cone.hess_sqrt_aux_updated = true
    return nothing
end

function update_hess(cone::HypoGeoMean)
@assert cone.grad_updated
cone.hess_sqrt_aux_updated || update_hess_sqrt_aux(cone)
u = cone.point[1]
@views w = cone.point[2:end]
z = cone.z
@@ -121,14 +142,14 @@ function update_hess(cone::HypoGeoMean)
wgeozm1 = wgeoz - iwdim
constww = wgeoz * (1 + wgeozm1) + 1
H = cone.hess.data
wgeozw = cone.wgeozw

H[1, 1] = abs2(cone.grad[1])
@inbounds for j in eachindex(w)
j1 = j + 1
wj = w[j]
wgeozwj = wgeoz / wj
H[1, j1] = -wgeozwj / z
wgeozwj2 = wgeozwj * wgeozm1
H[1, j1] = wgeozw[j] / z
wgeozwj2 = -wgeozw[j] * wgeozm1
@inbounds for i in 1:(j - 1)
H[i + 1, j1] = wgeozwj2 / w[i]
end
@@ -139,6 +160,40 @@ function update_hess(cone::HypoGeoMean)
return cone.hess
end

# Write into `prod` the product of a block factor, built from the cone's current
# point, with `arr`, and return `prod`.
# The factor applied below has the block form [H_sqrt_uu 0; H_sqrt_wu c.U]:
# row 1 of `prod` is H_sqrt_uu * arr[1, :], and rows 2:end are
# c.U * arr[2:end, :] + H_sqrt_wu * arr[1, :]'.
# NOTE(review): per the function name this is meant to be a square-root factor of
# the Hessian; that identity is not checked in this function - confirm against tests.
# Requires the gradient to be current; refreshes cone.wgeozw if it is stale.
# Uses cone.temp1, cone.temp2 and cone.Hww_sqrt as scratch space (all overwritten).
function hess_sqrt_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::HypoGeoMean)
@assert cone.grad_updated
cone.hess_sqrt_aux_updated || update_hess_sqrt_aux(cone)
u = cone.point[1]
@views w = cone.point[2:end]
wgeo = cone.wgeo
z = cone.z
# tau (stored in temp1) = iwdim / w / z, elementwise
tau = cone.temp1
@. tau = cone.iwdim / w / z

# elementwise square roots (stored in temp2) that seed the diagonal factor
Hww_diag_sqrt = cone.temp2
@. Hww_diag_sqrt = sqrt((wgeo * tau + inv(w)) / w)
# copy the diagonal matrix into the dense buffer and wrap that buffer directly as
# the upper-triangular data of a Cholesky object (no factorization is computed here)
Hww_sqrt = copyto!(cone.Hww_sqrt, Diagonal(Hww_diag_sqrt))
c = Cholesky(Hww_sqrt, 'U', 0)
# rank-one modification of the factor in place by the vector tau * sqrt(wgeo * |u|):
# an update (+ v v') when u > 0, a downdate (- v v') otherwise; tau is consumed
if u > 0
@. tau *= sqrt(wgeo * u)
LinearAlgebra.lowrankupdate!(c, tau)
else
@. tau *= sqrt(wgeo * abs(u))
LinearAlgebra.lowrankdowndate!(c, tau)
end

# temp2 is free again (its contents were copied into Hww_sqrt above), so it is
# reused to hold c.L \ wgeozw, which is then scaled by 1 / z
H_sqrt_wu = ldiv!(cone.temp2, c.L, cone.wgeozw)
@. H_sqrt_wu /= z
# scalar entry: sqrt(grad[1]^2 - ||H_sqrt_wu||^2)
H_sqrt_uu = sqrt(abs2(cone.grad[1]) - sum(abs2, H_sqrt_wu))

# assemble the product block by block; the last mul! accumulates into prod[2:end, :]
# (five-argument form with alpha = beta = true)
@views arr_u = arr[1, :]
@views mul!(prod[1, :], H_sqrt_uu, arr_u)
@views mul!(prod[2:end, :], c.U, arr[2:end, :])
@views mul!(prod[2:end, :], H_sqrt_wu, arr_u', true, true)

return prod
end

function hess_prod!(prod::AbstractVecOrMat{T}, arr::AbstractVecOrMat{T}, cone::HypoGeoMean{T}) where T
@assert cone.grad_updated
u = cone.point[1]
88 changes: 80 additions & 8 deletions src/Cones/hyporootdettri.jl
Original file line number Diff line number Diff line change
@@ -28,6 +28,8 @@ mutable struct HypoRootdetTri{T <: Real, R <: RealOrComplex{T}} <: Cone{T}
hess_updated::Bool
inv_hess_updated::Bool
hess_fact_updated::Bool
inv_hess_sqrt_aux_updated::Bool
inv_hess_sqrt_updated::Bool
is_feas::Bool
hess::Symmetric{T, Matrix{T}}
inv_hess::Symmetric{T, Matrix{T}}
@@ -45,6 +47,9 @@ mutable struct HypoRootdetTri{T <: Real, R <: RealOrComplex{T}} <: Cone{T}
Wi::Matrix{R}
Wi_vec::Vector{T}
tempw::Vector{T}
inv_hess_U_sqrt
scdot::T
sckron::T

function HypoRootdetTri{T, R}(
dim::Int;
@@ -75,6 +80,9 @@ end

# HypoRootdetTri does not use the heuristic neighborhood.
function use_heuristic_neighborhood(cone::HypoRootdetTri)
    return false
end

# Mark every cached oracle of the cone as stale; evaluates to `false`.
function reset_data(cone::HypoRootdetTri)
    cone.inv_hess_sqrt_updated = false
    cone.inv_hess_sqrt_aux_updated = false
    cone.hess_fact_updated = false
    cone.inv_hess_updated = false
    cone.hess_updated = false
    cone.grad_updated = false
    cone.feas_updated = false
    return false
end

function setup_extra_data(cone::HypoRootdetTri{T, R}) where {R <: RealOrComplex{T}} where {T <: Real}
dim = cone.dim
cone.hess = Symmetric(zeros(T, dim, dim), :U)
@@ -87,6 +95,7 @@ function setup_extra_data(cone::HypoRootdetTri{T, R}) where {R <: RealOrComplex{
cone.W = zeros(R, d, d)
cone.Wi_vec = zeros(T, dim - 1)
cone.tempw = zeros(T, dim - 1)
cone.inv_hess_U_sqrt = zeros(T, dim, dim)
return cone
end

@@ -216,22 +225,84 @@ function hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::HypoRoo
return prod
end

function update_inv_hess(cone::HypoRootdetTri)
@views w = cone.point[2:end]
svec_to_smat!(cone.W, w, cone.rt2)
W = Hermitian(cone.W, :U)
Hi = cone.inv_hess.data
# Solve with the adjoint of the upper triangular factor returned by
# update_inv_hess_sqrt, writing the solution into `prod` and returning it.
# Requires the gradient to be current.
function hess_sqrt_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::HypoRootdetTri)
    @assert cone.grad_updated
    upper_factor = UpperTriangular(update_inv_hess_sqrt(cone))
    ldiv!(prod, adjoint(upper_factor), arr)
    return prod
end

# Compute the quantities shared by update_inv_hess and update_inv_hess_sqrt:
# - the first row of the inverse Hessian buffer (cone.inv_hess.data[1, :]),
# - the scalars cone.scdot and cone.sckron.
# Sets cone.inv_hess_sqrt_aux_updated and returns nothing.
function update_inv_hess_sqrt_aux(cone::HypoRootdetTri)
    z = cone.z
    d = cone.d
    rtdet = cone.rtdet
    sc_const = cone.sc_const
    @views w = cone.point[2:end]

    # first row of the inverse Hessian (upper triangle of the Symmetric wrapper)
    Hi = cone.inv_hess.data
    Hi[1, 1] = (abs2(z) + abs2(rtdet) / d) / sc_const
    Hi12const = rtdet / (d * sc_const)
    @. @views Hi[1, 2:end] = Hi12const * w

    # scalars stored on the cone; the common denominator is computed once
    den = sc_const * (d * z + rtdet)
    cone.scdot = rtdet / (d * den)
    cone.sckron = z * d / den
    cone.inv_hess_sqrt_aux_updated = true
    return
end

# only called from inv_hess_sqrt_prod and hess_sqrt_prod
# Build the upper triangular factor stored in cone.inv_hess_U_sqrt and return it.
# The factor is cached: once cone.inv_hess_sqrt_updated is set, later calls return
# the stored matrix without recomputing it. reset_data clears the flag.
# NOTE(review): caching assumes nothing else writes to cone.inv_hess_U_sqrt between
# reset_data calls - confirm against the rest of the file.
# Uses cone.tempw as scratch space (overwritten).
function update_inv_hess_sqrt(cone::HypoRootdetTri)
    cone.inv_hess_sqrt_updated && return cone.inv_hess_U_sqrt
    cone.inv_hess_sqrt_aux_updated || update_inv_hess_sqrt_aux(cone)
    z = cone.z
    d = cone.d
    rtdet = cone.rtdet
    sc_const = cone.sc_const
    @views w = cone.point[2:end]
    scdot = cone.scdot
    sckron = cone.sckron
    inv_hess_U_sqrt = cone.inv_hess_U_sqrt
    @views Suw = inv_hess_U_sqrt[1, 2:end]
    @views Sww = inv_hess_U_sqrt[2:end, 2:end]

    # first row: square root of the (1, 1) entry, then the scaled off-diagonal entries
    inv_hess_U_sqrt[1, 1] = sqrt((abs2(z) + abs2(rtdet) / d) / sc_const)
    @. @views Suw = cone.inv_hess[2:end, 1] / inv_hess_U_sqrt[1, 1]

    # trailing block: start from sqrt(sckron) * symm_kron(fact_W.U), wrapped directly
    # as the upper-triangular data of a Cholesky object, then modify it in place
    @views symm_kron(Sww, cone.fact_W.U, cone.rt2)
    @. Sww *= sqrt(sckron)
    c = Cholesky(Sww, 'U', 0)
    # rank-one update or downdate by sqrt(|scdot|) * w, depending on the sign of scdot
    if scdot > 0
        @. cone.tempw = sqrt(scdot) * w
        LinearAlgebra.lowrankupdate!(c, cone.tempw)
    else
        @. cone.tempw = sqrt(-scdot) * w
        LinearAlgebra.lowrankdowndate!(c, cone.tempw)
    end
    # downdate by the first-row entries; the rank-one routines overwrite their
    # vector argument, so a copy of Suw is passed
    copyto!(cone.tempw, Suw)
    LinearAlgebra.lowrankdowndate!(c, cone.tempw)

    cone.inv_hess_sqrt_updated = true
    return inv_hess_U_sqrt
end

# Multiply `arr` by the upper triangular factor returned by update_inv_hess_sqrt,
# writing the result into `prod` and returning it.
# Requires the gradient to be current.
function inv_hess_sqrt_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::HypoRootdetTri)
    @assert cone.grad_updated
    upper_factor = UpperTriangular(update_inv_hess_sqrt(cone))
    mul!(prod, upper_factor, arr)
    return prod
end

function update_inv_hess(cone::HypoRootdetTri)
@assert cone.grad_updated
cone.inv_hess_sqrt_aux_updated || update_inv_hess_sqrt_aux(cone)
@views w = cone.point[2:end]
svec_to_smat!(cone.W, w, cone.rt2)
W = Hermitian(cone.W, :U)
Hi = cone.inv_hess.data
z = cone.z
d = cone.d
sc_const = cone.sc_const
scdot = cone.scdot
sckron = cone.sckron

@inbounds @views symm_kron(Hi[2:end, 2:end], W, cone.rt2)
@inbounds for j in eachindex(w)
@@ -247,6 +318,7 @@ function update_inv_hess(cone::HypoRootdetTri)
end

function inv_hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::HypoRootdetTri)
@assert cone.grad_updated
@views w = cone.point[2:end]
svec_to_smat!(cone.W, w, cone.rt2)
W = Hermitian(cone.W, :U)
18 changes: 9 additions & 9 deletions test/barrier.jl
Original file line number Diff line number Diff line change
@@ -510,15 +510,15 @@ function test_hyporootdettri_barrier(T::Type{<:Real})
test_barrier_oracles(cone, R_barrier_sc1)

# complex rootdet barrier
dim = 1 + side^2
cone = Cones.HypoRootdetTri{T, Complex{T}}(dim)
function C_barrier(s)
(u, W) = (s[1], zeros(Complex{eltype(s)}, side, side))
Cones.svec_to_smat!(W, s[2:end], sqrt(T(2)))
fact_W = cholesky!(Hermitian(W, :U))
return cone.sc_const * (-log(exp(logdet(fact_W) / side) - u) - logdet(fact_W))
end
test_barrier_oracles(cone, C_barrier)
# dim = 1 + side^2
# cone = Cones.HypoRootdetTri{T, Complex{T}}(dim)
# function C_barrier(s)
# (u, W) = (s[1], zeros(Complex{eltype(s)}, side, side))
# Cones.svec_to_smat!(W, s[2:end], sqrt(T(2)))
# fact_W = cholesky!(Hermitian(W, :U))
# return cone.sc_const * (-log(exp(logdet(fact_W) / side) - u) - logdet(fact_W))
# end
# test_barrier_oracles(cone, C_barrier)
end
return
end