Corrected, improved and cleaned BackTracking #172

Open · wants to merge 2 commits into master
src/backtracking.jl: 157 additions & 70 deletions

@@ -8,111 +8,198 @@ there exists a factor ρ = ρ(c₁) such that α' ≦ ρ α.

This is a modification of the algorithm described in Nocedal & Wright (2nd ed.), Sec. 3.5.
"""
-@with_kw struct BackTracking{TF, TI}
-    c_1::TF = 1e-4
-    ρ_hi::TF = 0.5
-    ρ_lo::TF = 0.1
-    iterations::TI = 1_000
-    order::TI = 3
-    maxstep::TF = Inf
+struct BackTracking{TF,TI}
+    c_1::TF
+    ρ_hi::TF
+    ρ_lo::TF
+    iterations::TI
+    order::TI
+    maxstep::TF
 end
Contributor:
Since the rest of the package uses 4-space indentation, this should too.

Author:
Thank you for pointing this out. The last commit solves this issue.
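For orientation, a minimal construction sketch using the constructors introduced below (the defaults shown are the ones in this diff):

    using LineSearches

    ls = BackTracking()
    # BackTracking{Float64,Int} with c_1 = 1e-4, ρ_hi = 0.5, ρ_lo = 0.1,
    # iterations = 1_000, order = 3, maxstep = Inf

    ls32 = BackTracking{Float32}(order = 2)
    # Float32 parameters; order = 2 restricts the search to quadratic
    # interpolation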

-BackTracking{TF}(args...; kwargs...) where TF = BackTracking{TF,Int}(args...; kwargs...)
-
-function (ls::BackTracking)(df::AbstractObjective, x::AbstractArray{T}, s::AbstractArray{T},
-                            α_0::Tα = real(T)(1), x_new::AbstractArray{T} = similar(x),
-                            ϕ_0 = nothing, dϕ_0 = nothing, alphamax = convert(real(T), Inf)) where {T, Tα}
+function BackTracking(args...; kwargs...)
+    BackTracking{Float64}(args...; kwargs...)
+end
+
+function BackTracking{TF}(args...; kwargs...) where {TF}
+    BackTracking{TF,Int}(args...; kwargs...)
+end

+function BackTracking{TF,TI}(;
+    c_1::Real=1.0e-4, ρ_hi::Real=0.5, ρ_lo::Real=0.1,
+    iterations::Integer=1_000, order::Integer=3, maxstep::Real=Inf
+) where {TF,TI}
+    # Impose 0 < ρ_hi < 1
+    if (ρ_hi <= 0) || (ρ_hi >= 1)
+        ρ_hi = 1 / 2
+        @warn """
+        The upper bound for the backtracking factor has to lie between
+        0 and 1.
+        Setting ρ_hi = $(ρ_hi).
+        """
+    end
+
+    # Impose 0 < ρ_lo <= ρ_hi < 1
+    if (ρ_lo <= 0) || (ρ_lo >= 1) || (ρ_lo > ρ_hi)
+        ρ_lo = ρ_hi / 5
+        @warn """
+        The lower bound for the backtracking factor has to lie between
+        0 and 1, and be smaller than the upper bound.
+        Setting ρ_lo = $(ρ_lo).
+        """
+    end
+
+    # Impose positive number of maximum iterations
+    if (iterations <= 0) || !isinteger(iterations)
+        iterations = trunc(Int, iterations)
+        if iterations <= 0
+            iterations = 1_000
+        end
+        @warn """
+        The maximum number of iterations has to be a positive integer.
+        Setting iterations = $(iterations).
+        """
+    end
+
+    # Impose order in (2, 3)
+    if order != 2 && order != 3
+        order = 3
+        @warn """
+        The order has to be either 2 or 3.
+        Setting order = $(order).
+        """
+    end
+
+    # Impose maxstep
+    if !isreal(maxstep)
+        maxstep = Inf
+        @warn """
+        The maximum step size has to be real.
+        Setting maxstep = $(maxstep).
+        """
+    end
+
+    # Impose c_1 > 0
+    if c_1 < 0
+        c_1 = 1.0e-4
+        @warn """
+        The Armijo constant has to be positive.
+        Setting c_1 = $(c_1).
+        """
+    end
+
+    # # Impose backtracking factor (for order = 2)
+    # # The quadratic update rule comes with a backtracking factor
+    # #     ρ = 1 / 2 / (1 - c_1).
+    # # We need c_1 > 0, and we want ρ < 1,
+    # # so 0 < c_1 < 1/2 and 1/2 < ρ < 1.
+    # ρ = ρ_hi # Could take another choice here
+    # c_1_ρ = 1 - 1 / (2 * ρ)
+    # if c_1 > c_1_ρ
+    #     c_1 = c_1_ρ
+    #     @warn """
+    #     The Armijo constant c_1 is too large.
+    #     Setting c_1 = $(c_1_ρ).
+    #     """
+    # end
+
+    BackTracking{TF,TI}(c_1, ρ_hi, ρ_lo, iterations, order, maxstep)
+end
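A short sketch of how the validation above behaves, matching the branches in this diff:

    using LineSearches

    BackTracking(ρ_hi = 1.2)        # warns and falls back to ρ_hi = 0.5
    BackTracking(order = 4)         # warns and falls back to order = 3
    BackTracking(iterations = -5)   # warns and falls back to iterations = 1_000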

+function (ls::BackTracking)(
+    df::AbstractObjective, x::AbstractArray{T}, s::AbstractArray{T},
+    α_0::Tα=real(T)(1), x_new::AbstractArray{T}=similar(x),
+    ϕ_0=nothing, dϕ_0=nothing, alphamax=convert(real(T), Inf)
+) where {T,Tα}
    ϕ, dϕ = make_ϕ_dϕ(df, x_new, x, s)

-    if ϕ_0 == nothing
+    if isnothing(ϕ_0)
        ϕ_0 = ϕ(Tα(0))
    end
-    if dϕ_0 == nothing
+    if isnothing(dϕ_0)
        dϕ_0 = dϕ(Tα(0))
    end

    α_0 = min(α_0, min(alphamax, ls.maxstep / norm(s, Inf)))
    ls(ϕ, α_0, ϕ_0, dϕ_0)
end
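A usage sketch for the objective-based method above; it assumes NLSolversBase's OnceDifferentiable as the AbstractObjective, and the objective, point, and direction are illustrative:

    using LineSearches, NLSolversBase

    f(x) = sum(abs2, x)                  # illustrative smooth objective
    x = [1.0, 2.0]
    df = OnceDifferentiable(f, x; autodiff = :forward)

    NLSolversBase.gradient!(df, x)       # evaluate ∇f(x)
    s = -NLSolversBase.gradient(df)      # steepest-descent direction

    ls = BackTracking()
    α, ϕα = ls(df, x, s)                 # ϕ_0 and dϕ_0 are computed internally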

-(ls::BackTracking)(ϕ, dϕ, ϕdϕ, αinitial, ϕ_0, dϕ_0) = ls(ϕ, αinitial, ϕ_0, dϕ_0)
+function (ls::BackTracking)(ϕ, dϕ, ϕdϕ, α_0, ϕ_0, dϕ_0)
+    ls(ϕ, α_0, ϕ_0, dϕ_0)
+end

# TODO: Should we deprecate the interface that only uses the ϕ argument?
-function (ls::BackTracking)(ϕ, αinitial::Tα, ϕ_0, dϕ_0) where Tα
+function (ls::BackTracking)(ϕ, α_0::Tα, ϕ_0, dϕ_0) where {Tα}
    @unpack c_1, ρ_hi, ρ_lo, iterations, order = ls
+    ε = eps(real(Tα))

-    iterfinitemax = -log2(eps(real(Tα)))
-
-    @assert order in (2,3)
-    # Check the input is valid, and modify otherwise
-    #backtrack_condition = 1.0 - 1.0/(2*ρ) # want guaranteed backtrack factor
-    #if c_1 >= backtrack_condition
-    #    warn("""The Armijo constant c_1 is too large; replacing it with
-    #            $(backtrack_condition)""")
-    #    c_1 = backtrack_condition
-    #end
-
-    # Count the total number of iterations
-    iteration = 0
+    # Initialise α_1 and α_2
+    α_1, α_2 = α_0, α_0
+    ϕ_1, ϕ_2 = ϕ_0, ϕ_0

-    ϕx_0, ϕx_1 = ϕ_0, ϕ_0
-
-    α_1, α_2 = αinitial, αinitial
-
-    ϕx_1 = ϕ(α_1)
-
-    # Hard-coded backtrack until we find a finite function value
+    # Backtrack until ϕ(α_2) is finite
    iterfinite = 0
-    while !isfinite(ϕx_1) && iterfinite < iterfinitemax
+    iterfinitemax = -log2(ε)
+    ϕ_2 = ϕ(α_1)
+    while !isfinite(ϕ_2) && iterfinite < iterfinitemax
        iterfinite += 1
        α_1 = α_2
-        α_2 = α_1/2
-
-        ϕx_1 = ϕ(α_2)
+        α_2 = α_1 / 2
+        ϕ_2 = ϕ(α_2)
    end

-    # Backtrack until we satisfy sufficient decrease condition
-    while ϕx_1 > ϕ_0 + c_1 * α_2 * dϕ_0
-        # Increment the number of steps we've had to perform
+    # Backtrack until sufficient decrease
+    iteration = 0
+    while (ϕ_2 > ϕ_0 + c_1 * α_2 * dϕ_0) && (iteration <= iterations)
        iteration += 1

-        # Ensure termination
-        if iteration > iterations
-            throw(LineSearchException("Linesearch failed to converge, reached maximum iterations $(iterations).",
-                                      α_2))
-        end
-
-        # Shrink proposed step-size:
-        if order == 2 || iteration == 1
-            # backtracking via quadratic interpolation:
+        if (order == 2) || (iteration == 1)
+            # Backtracking via quadratic interpolation:
            # This interpolates the available data
-            #    f(0), f'(0), f(α)
-            # with a quadractic which is then minimised; this comes with a
-            # guaranteed backtracking factor 0.5 * (1-c_1)^{-1} which is < 1
-            # provided that c_1 < 1/2; the backtrack_condition at the beginning
-            # of the function guarantees at least a backtracking factor ρ.
-            α_tmp = - (dϕ_0 * α_2^2) / ( 2 * (ϕx_1 - ϕ_0 - dϕ_0*α_2) )
+            #     ϕ(0), ϕ'(0), ϕ(α)
+            # with a quadratic which is then minimised.
+            α_tmp = -(dϕ_0 * α_2^2) / (2 * (ϕ_2 - ϕ_0 - dϕ_0 * α_2))
        else
-            div = one(Tα) / (α_1^2 * α_2^2 * (α_2 - α_1))
-            a = (α_1^2*(ϕx_1 - ϕ_0 - dϕ_0*α_2) - α_2^2*(ϕx_0 - ϕ_0 - dϕ_0*α_1))*div
-            b = (-α_1^3*(ϕx_1 - ϕ_0 - dϕ_0*α_2) + α_2^3*(ϕx_0 - ϕ_0 - dϕ_0*α_1))*div
+            # Backtracking via cubic interpolation:
+            # This interpolates the available data
+            #     ϕ(0), ϕ'(0), ϕ(α_1), ϕ(α_2)
+            # with a cubic function which is then minimised; the minimiser
+            # solves 3a·α² + 2b·α + dϕ_0 = 0.
+            α_1², α_2² = α_1^2, α_2^2
+            α_1³, α_2³ = α_1² * α_1, α_2² * α_2
+
+            δ_1 = ϕ_1 - ϕ_0 - dϕ_0 * α_1
+            δ_2 = ϕ_2 - ϕ_0 - dϕ_0 * α_2

-            if isapprox(a, zero(a), atol=eps(real(Tα)))
-                α_tmp = dϕ_0 / (2*b)
+            invdet = one(Tα) / (α_1² * α_2² * (α_2 - α_1))
+
+            a = (α_1² * δ_2 - α_2² * δ_1) * invdet
+            b = (α_2³ * δ_1 - α_1³ * δ_2) * invdet
+
+            if isapprox(a, zero(Tα), atol=ε)
+                # Degenerate quadratic case
+                α_tmp = -dϕ_0 / (2 * b)
            else
-                # discriminant
-                d = max(b^2 - 3*a*dϕ_0, Tα(0))
-                # quadratic equation root
-                α_tmp = (-b + sqrt(d)) / (3*a)
+                # General cubic case: the root (-b + √Δ) / (3a) is rewritten
+                # as -dϕ_0 / (b + √Δ) to avoid numerical cancellation
+                Δ = max(b^2 - 3 * a * dϕ_0, zero(Tα))
+                α_tmp = -dϕ_0 / (b + sqrt(Δ))
            end
        end

-        α_1 = α_2
+        # Clamp α_tmp to avoid too small / large reductions
+        α_tmp = NaNMath.min(α_tmp, α_2 * ρ_hi)
+        α_tmp = NaNMath.max(α_tmp, α_2 * ρ_lo)

-        α_tmp = NaNMath.min(α_tmp, α_2*ρ_hi) # avoid too small reductions
-        α_2 = NaNMath.max(α_tmp, α_2*ρ_lo) # avoid too big reductions
+        # Update (α_1, α_2)
+        α_1, α_2 = α_2, α_tmp
+        ϕ_1, ϕ_2 = ϕ_2, ϕ(α_2)
    end

-    # Evaluate f(x) at proposed position
-    ϕx_0, ϕx_1 = ϕx_1, ϕ(α_2)
+    # Ensure termination
Member:
Hm, I think this will cause the whole optimize call to fail with this error in Optim.jl, right?

Author:

Indeed. This is already what happens in the current implementation. I just moved current lines 77-81 to the end of the loop.

+    if iteration > iterations
+        msg = "Linesearch failed to converge, reached maximum iterations $(iterations)."
+        throw(LineSearchException(msg, α_2))
+    end

-    return α_2, ϕx_1
+    return α_2, ϕ_2
end
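To make the loop concrete, a small worked example through the ϕ-only interface above. Take f(x) = x² at x = 1 with direction s = -f′(1) = -2, so ϕ(α) = (1 - 2α)², ϕ(0) = 1, and dϕ_0 = -4:

    using LineSearches

    ϕ(α) = (1 - 2α)^2
    ls = BackTracking()
    α, ϕα = ls(ϕ, 1.0, ϕ(0.0), -4.0)
    # The trial α = 1 fails sufficient decrease: ϕ(1) = 1 > 1 - 4e-4.
    # Quadratic interpolation proposes α_tmp = 4/8 = 0.5, already inside the
    # clamp interval [ρ_lo, ρ_hi] · α = [0.1, 0.5]; ϕ(0.5) = 0 passes the
    # condition, so the call returns (α, ϕα) == (0.5, 0.0).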