diff --git a/examples/robustgeomprog/JuMP.jl b/examples/robustgeomprog/JuMP.jl index ffc0332f0..29e311749 100644 --- a/examples/robustgeomprog/JuMP.jl +++ b/examples/robustgeomprog/JuMP.jl @@ -28,7 +28,7 @@ example_tests(::Type{RobustGeomProgJuMP{Float64}}, ::FastInstances) = begin options = (tol_feas = 1e-6, tol_rel_opt = 1e-6, tol_abs_opt = 1e-6) return [ ((5, 10), nothing, options), - ((5, 10), ClassicConeOptimizer), nothing, options), + ((5, 10), ClassicConeOptimizer, options), ((10, 20), nothing, options), ((20, 40), nothing, options), ((40, 80),), diff --git a/src/Cones/Cones.jl b/src/Cones/Cones.jl index 4e22bd8d9..168243374 100644 --- a/src/Cones/Cones.jl +++ b/src/Cones/Cones.jl @@ -100,7 +100,7 @@ function update_hess_fact(cone::Cone{T}; recover::Bool = true) where {T <: Real} recover || return false # TODO if Chol, try adding sqrt(eps(T)) to diag and re-factorize if T <: BlasReal && cone.hess_fact_cache isa DensePosDefCache{T} - @warn("switching Hessian cache from Cholesky to Bunch Kaufman") + # @warn("switching Hessian cache from Cholesky to Bunch Kaufman") cone.hess_fact_cache = DenseSymCache{T}() load_matrix(cone.hess_fact_cache, cone.hess) else @@ -253,6 +253,50 @@ end # # return (nbhd < T(0.5)) # end +# newton for central initial point +# TODO don't run if cone has known central initial point +# TODO remove allocs +function set_central_point(cone::Cone{T}) where {T <: Real} + tol = cbrt(eps(T)) # TODO adjust + max_iter = 10 # TODO make it depend on sqrt(nu)? + damp_tol = 0.2 # TODO tune + nu = get_nu(cone) + + curr = zeros(T, dimension(cone)) + set_initial_point(curr, cone) + curr .*= sqrt(nu / sum(abs2, curr)) # rescale norm as a heuristic + + dir = similar(curr) + iter = 0 + while iter < max_iter + load_point(cone, curr) + reset_data(cone) + @assert is_feas(cone) + g = grad(cone) + + tmp = -curr - g + # @show norm(tmp) + if norm(tmp, Inf) < tol # TODO tune + iter > 0 && println("final iter $iter, $(norm(tmp, Inf))") + break + end + + dir .= cholesky!(Symmetric(hess(cone) + I)) \ tmp # TODO make more efficient, maybe add to cone.hess directly and use hess fact + # inv_hess_prod!(dir, tmp, cone) # cannot use + + nnorm = dot(tmp, dir) + @assert nnorm > 0 + alpha = (abs(nnorm) > damp_tol ? inv(1 + abs(nnorm)) : one(T)) + # @show alpha + # @show nnorm + + curr = curr + alpha * dir + iter += 1 + end + # @show norm(curr + grad(cone)) + + return curr +end # utilities for arrays diff --git a/src/Solvers/initialize.jl b/src/Solvers/initialize.jl index e28773c78..dbfd395da 100644 --- a/src/Solvers/initialize.jl +++ b/src/Solvers/initialize.jl @@ -9,17 +9,28 @@ function initialize_cone_point(cones::Vector{Cones.Cone{T}}, cone_idxs::Vector{U q = isempty(cones) ? 0 : sum(Cones.dimension, cones) point = Models.Point(T[], T[], zeros(T, q), zeros(T, q), cones, cone_idxs) + # TODO cleanup + use_newton = true + # use_newton = false + for (k, cone_k) in enumerate(cones) Cones.setup_data(cone_k) Cones.set_timer(cone_k, timer) primal_k = point.primal_views[k] - Cones.set_initial_point(primal_k, cone_k) + if use_newton + primal_k .= Cones.set_central_point(cone_k) # TODO pass in arg? + else + Cones.set_initial_point(primal_k, cone_k) + end Cones.load_point(cone_k, primal_k) dual_k = point.dual_views[k] @assert Cones.is_feas(cone_k) g = Cones.grad(cone_k) @. dual_k = -g - Cones.load_dual_point(cone_k, dual_k) + # if use_newton + # @show norm(primal_k - dual_k) + # # @assert primal_k ≈ dual_k rtol=eps(T)^0.25 # TODO delete + # end hasfield(typeof(cone_k), :hess_fact_cache) && @assert Cones.update_hess_fact(cone_k) end diff --git a/src/Solvers/stepper.jl b/src/Solvers/stepper.jl index 018ff9c59..60a12e70e 100644 --- a/src/Solvers/stepper.jl +++ b/src/Solvers/stepper.jl @@ -386,8 +386,8 @@ function update_rhs_predcorr(stepper::CombinedStepper{T}, solver::Solver{T}) whe # end if corr_viol < 0.001 @. stepper.s_rhs_k[k] += H_prim_dir_k + corr_k - else - println("skip pred-corr: $corr_viol") + # else + # println("skip pred-corr: $corr_viol") end end @@ -444,8 +444,8 @@ function update_rhs_centcorr(stepper::CombinedStepper{T}, solver::Solver{T}) whe # end if corr_viol < 0.001 stepper.s_rhs_k[k] .+= corr_k - else - println("skip cent-corr: $corr_viol") + # else + # println("skip cent-corr: $corr_viol") end end