Skip to content

Commit

Permalink
start adding initial central point finding
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscoey committed Jul 7, 2020
1 parent bb654f5 commit 3057c2f
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/robustgeomprog/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ example_tests(::Type{RobustGeomProgJuMP{Float64}}, ::FastInstances) = begin
options = (tol_feas = 1e-6, tol_rel_opt = 1e-6, tol_abs_opt = 1e-6)
return [
((5, 10), nothing, options),
((5, 10), ClassicConeOptimizer), nothing, options),
((5, 10), ClassicConeOptimizer, options),
((10, 20), nothing, options),
((20, 40), nothing, options),
((40, 80),),
Expand Down
46 changes: 45 additions & 1 deletion src/Cones/Cones.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function update_hess_fact(cone::Cone{T}; recover::Bool = true) where {T <: Real}
recover || return false
# TODO if Chol, try adding sqrt(eps(T)) to diag and re-factorize
if T <: BlasReal && cone.hess_fact_cache isa DensePosDefCache{T}
@warn("switching Hessian cache from Cholesky to Bunch Kaufman")
# @warn("switching Hessian cache from Cholesky to Bunch Kaufman")
cone.hess_fact_cache = DenseSymCache{T}()
load_matrix(cone.hess_fact_cache, cone.hess)
else
Expand Down Expand Up @@ -253,6 +253,50 @@ end
# # return (nbhd < T(0.5))
# end

# newton for central initial point
# TODO don't run if cone has known central initial point
# TODO remove allocs
function set_central_point(cone::Cone{T}) where {T <: Real}
tol = cbrt(eps(T)) # TODO adjust
max_iter = 10 # TODO make it depend on sqrt(nu)?
damp_tol = 0.2 # TODO tune
nu = get_nu(cone)

curr = zeros(T, dimension(cone))
set_initial_point(curr, cone)
curr .*= sqrt(nu / sum(abs2, curr)) # rescale norm as a heuristic

dir = similar(curr)
iter = 0
while iter < max_iter
load_point(cone, curr)
reset_data(cone)
@assert is_feas(cone)
g = grad(cone)

tmp = -curr - g
# @show norm(tmp)
if norm(tmp, Inf) < tol # TODO tune
iter > 0 && println("final iter $iter, $(norm(tmp, Inf))")
break
end

dir .= cholesky!(Symmetric(hess(cone) + I)) \ tmp # TODO make more efficient, maybe add to cone.hess directly and use hess fact
# inv_hess_prod!(dir, tmp, cone) # cannot use

nnorm = dot(tmp, dir)
@assert nnorm > 0
alpha = (abs(nnorm) > damp_tol ? inv(1 + abs(nnorm)) : one(T))
# @show alpha
# @show nnorm

curr = curr + alpha * dir
iter += 1
end
# @show norm(curr + grad(cone))

return curr
end

# utilities for arrays

Expand Down
15 changes: 13 additions & 2 deletions src/Solvers/initialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,28 @@ function initialize_cone_point(cones::Vector{Cones.Cone{T}}, cone_idxs::Vector{U
q = isempty(cones) ? 0 : sum(Cones.dimension, cones)
point = Models.Point(T[], T[], zeros(T, q), zeros(T, q), cones, cone_idxs)

# TODO cleanup
use_newton = true
# use_newton = false

for (k, cone_k) in enumerate(cones)
Cones.setup_data(cone_k)
Cones.set_timer(cone_k, timer)
primal_k = point.primal_views[k]
Cones.set_initial_point(primal_k, cone_k)
if use_newton
primal_k .= Cones.set_central_point(cone_k) # TODO pass in arg?
else
Cones.set_initial_point(primal_k, cone_k)
end
Cones.load_point(cone_k, primal_k)
dual_k = point.dual_views[k]
@assert Cones.is_feas(cone_k)
g = Cones.grad(cone_k)
@. dual_k = -g
Cones.load_dual_point(cone_k, dual_k)
# if use_newton
# @show norm(primal_k - dual_k)
# # @assert primal_k ≈ dual_k rtol=eps(T)^0.25 # TODO delete
# end
hasfield(typeof(cone_k), :hess_fact_cache) && @assert Cones.update_hess_fact(cone_k)
end

Expand Down
8 changes: 4 additions & 4 deletions src/Solvers/stepper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,8 @@ function update_rhs_predcorr(stepper::CombinedStepper{T}, solver::Solver{T}) whe
# end
if corr_viol < 0.001
@. stepper.s_rhs_k[k] += H_prim_dir_k + corr_k
else
println("skip pred-corr: $corr_viol")
# else
# println("skip pred-corr: $corr_viol")
end
end

Expand Down Expand Up @@ -444,8 +444,8 @@ function update_rhs_centcorr(stepper::CombinedStepper{T}, solver::Solver{T}) whe
# end
if corr_viol < 0.001
stepper.s_rhs_k[k] .+= corr_k
else
println("skip cent-corr: $corr_viol")
# else
# println("skip cent-corr: $corr_viol")
end
end

Expand Down

0 comments on commit 3057c2f

Please sign in to comment.