Skip to content

Commit

Permalink
linesearch : remove getter and setter for options, make struct public
Browse files Browse the repository at this point in the history
  • Loading branch information
ManifoldFR committed Oct 29, 2024
1 parent 78833d1 commit e9ccee0
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 16 deletions.
9 changes: 8 additions & 1 deletion bindings/python/expose-solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ void exposeSolver() {
.value("LDLT_PROXSUITE", LDLTChoice::PROXSUITE)
.export_values();

using LinesearchOptions = Linesearch<Scalar>::Options;
using Linesearch = Linesearch<Scalar>;
using LinesearchOptions = Linesearch::Options;
bp::class_<Linesearch>("Linesearch", bp::no_init)
.def(bp::init<const LinesearchOptions &>(("self"_a, "options")))
.def_readwrite("options", &Linesearch::options_);
bp::class_<ArmijoLinesearch<Scalar>, bp::bases<Linesearch>>(
"ArmijoLinesearch", bp::no_init)
.def(bp::init<const LinesearchOptions &>(("self"_a, "options")));
bp::class_<LinesearchOptions>("LinesearchOptions", "Linesearch options.",
bp::init<>(("self"_a), "Default constructor."))
.def_readwrite("armijo_c1", &LinesearchOptions::armijo_c1)
Expand Down
26 changes: 13 additions & 13 deletions include/proxsuite-nlp/linesearch-armijo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ template <typename Scalar>
class ArmijoLinesearch final : public Linesearch<Scalar> {
public:
using Base = Linesearch<Scalar>;
using Base::options_;
using FunctionSample = typename Base::FunctionSample;
using Polynomial = PolynomialTpl<Scalar>;
using VectorXs = typename math_types<Scalar>::VectorXs;
Expand All @@ -71,26 +72,26 @@ class ArmijoLinesearch final : public Linesearch<Scalar> {
break;
} catch (const std::runtime_error &e) {
alpha_try *= 0.5;
if (alpha_try <= options().alpha_min) {
alpha_try = options().alpha_min;
if (alpha_try <= options_.alpha_min) {
alpha_try = options_.alpha_min;
break;
}
}
}

if (std::abs(dphi0) < options().dphi_thresh) {
if (std::abs(dphi0) < options_.dphi_thresh) {
return latest.phi;
}

for (std::size_t i = 0; i < options().max_num_steps; i++) {
for (std::size_t i = 0; i < options_.max_num_steps; i++) {

const Scalar dM = latest.phi - phi0;
if (dM <= options().armijo_c1 * alpha_try * dphi0) {
if (dM <= options_.armijo_c1 * alpha_try * dphi0) {
break;
}

// compute next alpha try
LSInterpolation strat = options().interp_type;
LSInterpolation strat = options_.interp_type;
if (strat == LSInterpolation::BISECTION) {
alpha_try *= 0.5;
} else {
Expand Down Expand Up @@ -119,15 +120,15 @@ class ArmijoLinesearch final : public Linesearch<Scalar> {
}

alpha_try = this->minimize_interpolant(
strat, options().contraction_min * alpha_try,
options().contraction_max * alpha_try);
strat, options_.contraction_min * alpha_try,
options_.contraction_max * alpha_try);
}

if (std::isnan(alpha_try)) {
// handle NaN case
alpha_try = options().contraction_min * previous.alpha;
alpha_try = options_.contraction_min * previous.alpha;
} else {
alpha_try = std::max(alpha_try, options().alpha_min);
alpha_try = std::max(alpha_try, options_.alpha_min);
}

try {
Expand All @@ -137,11 +138,11 @@ class ArmijoLinesearch final : public Linesearch<Scalar> {
continue;
}

if (alpha_try <= options().alpha_min) {
if (alpha_try <= options_.alpha_min) {
break;
}
}
alpha_try = std::max(alpha_try, options().alpha_min);
alpha_try = std::max(alpha_try, options_.alpha_min);
return latest.phi;
}

Expand Down Expand Up @@ -221,7 +222,6 @@ class ArmijoLinesearch final : public Linesearch<Scalar> {
}

protected:
using Base::options;
Polynomial interpolant;
std::vector<FunctionSample> samples; // interpolation samples
};
Expand Down
2 changes: 0 additions & 2 deletions include/proxsuite-nlp/linesearch-base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ template <typename T> class Linesearch {
FunctionSample(T a, T v, T g) : alpha(a), phi(v), dphi(g), valid(true) {}
};

const Linesearch::Options &options() const { return options_; }
void setOptions(const Linesearch::Options &options) { options_ = options; }

void reset() {}

private:
Linesearch::Options options_;
};

Expand Down

0 comments on commit e9ccee0

Please sign in to comment.