Skip to content

Commit

Permalink
improve predorcent stepper and use curve search (#633)
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscoey authored Dec 29, 2020
1 parent fcbf96b commit 7b49c38
Show file tree
Hide file tree
Showing 31 changed files with 785 additions and 976 deletions.
2 changes: 1 addition & 1 deletion examples/common_JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ cone_from_hyp(cone::Cones.EpiNormInf) = (Cones.use_dual_barrier(cone) ? MOI.Norm
cone_from_hyp(cone::Cones.EpiNormEucl) = MOI.SecondOrderCone(Cones.dimension(cone))
cone_from_hyp(cone::Cones.EpiPerSquare) = MOI.RotatedSecondOrderCone(Cones.dimension(cone))
cone_from_hyp(cone::Cones.HypoPerLog) = (@assert Cones.dimension(cone) == 3; MOI.ExponentialCone())
cone_from_hyp(cone::Cones.EpiSumPerEntropy) = MOI.RelativeEntropyCone(Cones.dimension(cone))
cone_from_hyp(cone::Cones.EpiRelEntropy) = MOI.RelativeEntropyCone(Cones.dimension(cone))
cone_from_hyp(cone::Cones.HypoGeoMean) = MOI.GeometricMeanCone(Cones.dimension(cone))
cone_from_hyp(cone::Cones.Power) = (@assert Cones.dimension(cone) == 3; MOI.PowerCone{Float64}(cone.alpha[1]))
cone_from_hyp(cone::Cones.EpiNormSpectral) = (Cones.use_dual_barrier(cone) ? MOI.NormNuclearCone : MOI.NormSpectralCone)(cone.n, cone.m)
Expand Down
2 changes: 1 addition & 1 deletion examples/densityest/native.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ function build(inst::DensityEstNative{T}) where {T <: Real}
G_likl[row_offset + 1, 2:(1 + U)] = -X_pts_polys[i, :]
G_likl[row_offset + 2, ext_offset] = -1
row_offset += 3
push!(cones, Cones.EpiSumPerEntropy{T}(3))
push!(cones, Cones.EpiRelEntropy{T}(3))
end
end
else
Expand Down
4 changes: 2 additions & 2 deletions examples/matrixcompletion/JuMP_benchmark.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

matrixcompletion_insts = [
[(k, d) for d in vcat(3, 10:5:max_d)] # includes compile run
for (k, max_d) in ((5, 60), (10, 45))
[(k, d) for d in vcat(2, 5:5:max_d)] # includes compile run
for (k, max_d) in ((10, 45), (20, 35))
]

insts = Dict()
Expand Down
29 changes: 17 additions & 12 deletions src/Cones/Cones.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ include("nonnegative.jl")
include("epinorminf.jl")
include("epinormeucl.jl")
include("epipersquare.jl")
include("episumperentropy.jl")
include("epitracerelentropytri.jl")
include("epirelentropy.jl")
include("hypoperlog.jl")
include("power.jl")
include("hypopowermean.jl")
Expand All @@ -50,6 +49,7 @@ include("doublynonnegativetri.jl")
include("matrixepipersquare.jl")
include("hypoperlogdettri.jl")
include("hyporootdettri.jl")
include("epitracerelentropytri.jl")
include("wsosinterpnonnegative.jl")
include("wsosinterpepinormone.jl")
include("wsosinterpepinormeucl.jl")
Expand Down Expand Up @@ -190,9 +190,6 @@ function in_neighborhood(cone::Cone{T}, rtmu::T, max_nbhd::T) where {T <: Real}
if use_heuristic_neighborhood(cone)
nbhd = norm(vec1, Inf) / norm(g, Inf)
# nbhd = maximum(abs(dj / gj) for (dj, gj) in zip(vec1, g)) # TODO try this neighborhood
# elseif Cones.use_sqrt_oracles(cone) # NOTE can force factorization when we don't need it - better to use inv hess prod
# inv_hess_sqrt_prod!(vec2, vec1, cone)
# nbhd = norm(vec2)
else
inv_hess_prod!(vec2, vec1, cone)
nbhd_sqr = dot(vec2, vec1)
Expand Down Expand Up @@ -506,23 +503,31 @@ function arrow_prod(prod::AbstractVecOrMat{T}, arr::AbstractVecOrMat{T}, uu::T,
return prod
end

# square root with the argument clamped from below at eps(T), so that zero or
# negative inputs yield sqrt(eps(T)) instead of zero or a DomainError
function sqrt_pos(x::T) where {T <: Real}
    lower_bound = eps(T)
    clamped = max(x, lower_bound)
    return sqrt(clamped)
end

# factorize arrow matrix
# writes rtww = sqrt.(ww) and rtuw = uw ./ rtww in place and returns
# sqrt(uu - sum(abs2, rtuw))
# returns zero(T) as a failure sentinel if any entry of ww, or the final value
# under the square root, is below sqrt(eps(T)); in that case the output vectors
# may be only partially updated, so callers should check iszero on the result
function arrow_sqrt(uu::T, uw::Vector{T}, ww::Vector{T}, rtuw::Vector{T}, rtww::Vector{T}) where {T <: Real}
    tol = sqrt(eps(T))
    # bail out before taking square roots of entries that are too small
    any(<(tol), ww) && return zero(T)
    @. rtww = sqrt(ww)
    @. rtuw = uw / rtww
    diff = uu - sum(abs2, rtuw)
    # signal failure rather than clamping a too-small final value
    (diff < tol) && return zero(T)
    return sqrt(diff)
end

# 2x2 block case
# writes rtww, rtvw, rtuw, rtvv, rtuv in place (in that order) and returns
# sqrt(uu - sum(abs2, rtuv) - sum(abs2, rtuw))
# returns zero(T) as a failure sentinel if any entry of ww, any entry of
# vv - abs2.(rtvw), or the final value under the square root is below
# sqrt(eps(T)); in that case the output vectors may be only partially updated,
# so callers should check iszero on the result
function arrow_sqrt(uu::T, uv::Vector{T}, uw::Vector{T}, vv::Vector{T}, vw::Vector{T}, ww::Vector{T}, rtuv::Vector{T}, rtuw::Vector{T}, rtvv::Vector{T}, rtvw::Vector{T}, rtww::Vector{T}) where {T <: Real}
    tol = sqrt(eps(T))
    # bail out before taking square roots of entries that are too small
    any(<(tol), ww) && return zero(T)
    @. rtww = sqrt(ww)
    @. rtvw = vw / rtww
    @. rtuw = uw / rtww
    # rtuv is used as scratch space here so the check needs no extra allocation
    @. rtuv = vv - abs2(rtvw)
    any(<(tol), rtuv) && return zero(T)
    @. rtvv = sqrt(rtuv)
    @. rtuv = (uv - rtvw * rtuw) / rtvv
    diff = uu - sum(abs2, rtuv) - sum(abs2, rtuw)
    # signal failure rather than clamping a too-small final value
    (diff < tol) && return zero(T)
    return sqrt(diff)
end

# lmul with lower Cholesky factor of arrow matrix
Expand Down
41 changes: 25 additions & 16 deletions src/Cones/epinorminf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mutable struct EpiNormInf{T <: Real, R <: RealOrComplex{T}} <: Cone{T}
inv_hess_updated::Bool
hess_aux_updated::Bool
hess_sqrt_aux_updated::Bool
use_hess_sqrt::Bool
inv_hess_aux_updated::Bool
is_feas::Bool
hess::Symmetric{T, SparseMatrixCSC{T, Int}}
Expand Down Expand Up @@ -69,7 +70,12 @@ use_heuristic_neighborhood(cone::EpiNormInf) = false

# mark every cached oracle quantity of the cone as stale; returns false
function reset_data(cone::EpiNormInf)
    cone.inv_hess_aux_updated = false
    cone.hess_sqrt_aux_updated = false
    cone.hess_aux_updated = false
    cone.inv_hess_updated = false
    cone.hess_updated = false
    cone.grad_updated = false
    cone.feas_updated = false
    return false
end

# whether the Hessian sqrt oracles may be used for this cone
# returns false immediately if cone.use_hess_sqrt is already false; otherwise
# makes sure the sqrt factorization data is up to date and reports the flag
# afterwards, since update_hess_sqrt_aux may set cone.use_hess_sqrt
function use_sqrt_oracles(cone::EpiNormInf)
    cone.use_hess_sqrt || return false
    cone.hess_sqrt_aux_updated || update_hess_sqrt_aux(cone)
    return cone.use_hess_sqrt
end

# TODO only allocate the fields we use
function setup_extra_data(cone::EpiNormInf{T, R}) where {R <: RealOrComplex{T}} where {T <: Real}
Expand All @@ -93,6 +99,7 @@ function setup_extra_data(cone::EpiNormInf{T, R}) where {R <: RealOrComplex{T}}
cone.Hiuim = zeros(T, n)
cone.idet = zeros(T, n)
end
cone.use_hess_sqrt = true
return cone
end

Expand All @@ -109,7 +116,7 @@ function update_feas(cone::EpiNormInf{T}) where T
u = cone.point[1]
@views vec_copy_to!(cone.w, cone.point[2:end])

cone.is_feas = (u > eps(T) && abs2(u) - maximum(abs2, cone.w) > eps(T))
cone.is_feas = (u > eps(T) && u - norm(cone.w, Inf) > eps(T))

cone.feas_updated = true
return cone.is_feas
Expand Down Expand Up @@ -137,10 +144,9 @@ function update_grad(cone::EpiNormInf{T, R}) where {R <: RealOrComplex{T}} where
w = cone.w
den = cone.den

@inbounds for (j, wj) in enumerate(w)
absw = abs(wj)
den[j] = T(0.5) * (u - absw) * (u + absw)
end
usqr = abs2(u)
@. den = usqr - abs2(w)
den .*= T(0.5)
@. cone.uden = u / den
@. cone.wden = w / den
cone.grad[1] = (cone.n - 1) / u - sum(cone.uden)
Expand Down Expand Up @@ -236,6 +242,9 @@ function update_inv_hess_aux(cone::EpiNormInf{T}) where {T <: Real}
schur += inv(u2pwj2)
end
cone.schur = schur
if schur < zero(T)
@warn("bad schur $schur")
end

if cone.is_complex
@. cone.idet = cone.Hrere * cone.Himim - abs2(cone.Hreim)
Expand All @@ -253,23 +262,22 @@ function update_inv_hess(cone::EpiNormInf{T}) where {T <: Real}
Hi = cone.inv_hess.data
wden = cone.wden
u = cone.point[1]
schur = cone.schur

Hi[1, 1] = 1
Hi[1, 1] = inv(schur)
@inbounds for j in 1:cone.n
if cone.is_complex
Hi[1, 2j] = cone.Hiure[j]
Hi[1, 2j + 1] = cone.Hiuim[j]
Hi[2j, 1] = cone.Hiure[j]
Hi[2j + 1, 1] = cone.Hiuim[j]
else
Hi[1, j + 1] = cone.Hiure[j]
Hi[j + 1, 1] = cone.Hiure[j]
end
end
@. Hi[1, 2:end] = Hi[2:end, 1] / schur

rtschur = sqrt_pos(cone.schur)
Hi[1, :] ./= rtschur
@inbounds for j in 2:cone.dim, i in 2:j
Hi[i, j] = Hi[1, j] * Hi[1, i]
Hi[i, j] = Hi[j, 1] * Hi[1, i]
end
Hi[1, :] ./= rtschur

if cone.is_complex
@inbounds for j in 1:cone.n
Expand Down Expand Up @@ -330,12 +338,13 @@ function update_hess_sqrt_aux(cone::EpiNormInf)
else
cone.rtuu = arrow_sqrt(cone.Huu, cone.Hure, cone.Hrere, cone.rture, cone.rtrere)
end
cone.use_hess_sqrt = !iszero(cone.rtuu)
cone.hess_sqrt_aux_updated = true
return
end

function hess_sqrt_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiNormInf)
cone.hess_sqrt_aux_updated || update_hess_sqrt_aux(cone)
@assert cone.hess_sqrt_aux_updated && cone.use_hess_sqrt
if cone.is_complex
return arrow_sqrt_prod(prod, arr, cone.rtuu, cone.rture, cone.rtuim, cone.rtrere, cone.rtreim, cone.rtimim)
else
Expand All @@ -344,7 +353,7 @@ function hess_sqrt_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::Ep
end

function inv_hess_sqrt_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiNormInf)
cone.hess_sqrt_aux_updated || update_hess_sqrt_aux(cone)
@assert cone.hess_sqrt_aux_updated && cone.use_hess_sqrt
if cone.is_complex
return inv_arrow_sqrt_prod(prod, arr, cone.rtuu, cone.rture, cone.rtuim, cone.rtrere, cone.rtreim, cone.rtimim)
else
Expand Down
Loading

0 comments on commit 7b49c38

Please sign in to comment.