Skip to content

Commit

Permalink
Merge pull request #145 from mrc-ide/mrc-5992
Browse files Browse the repository at this point in the history
Make control a public field
  • Loading branch information
weshinsley authored Dec 13, 2024
2 parents 5501f39 + b136a9f commit 91b04a6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 30 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: dust2
Title: Next Generation dust
Version: 0.3.14
Version: 0.3.15
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Imperial College of Science, Technology and Medicine",
Expand Down
53 changes: 25 additions & 28 deletions inst/include/dust2/continuous/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,20 @@ struct internals {
template <typename real_type>
class solver {
public:
ode::control<real_type> control;

solver(size_t n_variables, ode::control<real_type> control) :
control(control),
n_variables_(n_variables),
control_(control),
y_next_(n_variables_),
y_stiff_(n_variables_),
k2_(n_variables_),
k3_(n_variables_),
k4_(n_variables_),
k5_(n_variables_),
k6_(n_variables_),
facc1_(1 / control_.factor_min),
facc2_(1 / control_.factor_max) {
}

// TODO: probably better to make this a public field (mrc-5992)
auto& control() {
return control_;
facc1_(1 / control.factor_min),
facc2_(1 / control.factor_max) {
}

template <typename Rhs>
Expand Down Expand Up @@ -137,8 +134,8 @@ class solver {
}

// Compute error:
const auto atol = control_.atol;
const auto rtol = control_.rtol;
const auto atol = control.atol;
const auto rtol = control.rtol;
real_type err = 0.0;
for (size_t i = 0; i < n_variables_; ++i) {
auto sk =
Expand All @@ -160,12 +157,12 @@ class solver {
auto h = internals.step_size;

while (!success) {
if (internals.n_steps > control_.max_steps) {
if (internals.n_steps > control.max_steps) {
// throw a nicer error for all of these, with the current
// time etc.
throw std::runtime_error("too many steps");
}
if (h < control_.step_size_min) {
if (h < control.step_size_min) {
throw std::runtime_error("step too small");
}
if (h <= std::abs(t) * std::numeric_limits<real_type>::epsilon()) {
Expand All @@ -179,7 +176,7 @@ class solver {
const auto err = try_step(t, h, y, internals.dydt.data(),
internals.last.c5.data(), rhs);
internals.n_steps++;
const auto fac11 = std::pow(err, control_.constant);
const auto fac11 = std::pow(err, control.constant);

if (err <= 1) {
success = true;
Expand All @@ -195,17 +192,17 @@ class solver {
}
accept(t, h, y, internals);
internals.n_steps_accepted++;
if (control_.debug_record_step_times) {
if (control.debug_record_step_times) {
internals.step_times.push_back(truncated ? t_end : t + h);
}
internals.save_history();
if (!truncated && !event) {
const auto fac_old =
std::max(internals.error, static_cast<real_type>(1e-4));
auto fac = fac11 / std::pow(fac_old, control_.beta);
fac = clamp(fac / control_.factor_safe, facc2_, facc1_);
auto fac = fac11 / std::pow(fac_old, control.beta);
fac = clamp(fac / control.factor_safe, facc2_, facc1_);
const auto h_new = h / fac;
const auto h_max = reject ? h : control_.step_size_max;
const auto h_max = reject ? h : control.step_size_max;
internals.step_size = std::min(h_new, h_max);
internals.error = err;
}
Expand All @@ -214,7 +211,7 @@ class solver {
if (internals.n_steps_accepted >= 1) {
internals.n_steps_rejected++;
}
h /= std::min(facc1_, fac11 / control_.factor_safe);
h /= std::min(facc1_, fac11 / control.factor_safe);
}
}

Expand All @@ -226,7 +223,7 @@ class solver {
zero_every_type<real_type>& zero_every,
const events_type<real_type>& events,
ode::internals<real_type>& internals, Rhs rhs) {
if (control_.critical_times.empty()) {
if (control.critical_times.empty()) {
while (t < t_end) {
apply_zero_every(t, y, zero_every, internals);
t = step(t, t_end, y, events, internals, rhs);
Expand All @@ -235,8 +232,8 @@ class solver {
// Slightly more complex loop which ensures we never integrate
// over the times within our critical times. The upper loop is
// a special case of this but is kept simple.
auto tc_end = control_.critical_times.end();
auto tc = std::upper_bound(control_.critical_times.begin(), tc_end, t);
auto tc_end = control.critical_times.end();
auto tc = std::upper_bound(control.critical_times.begin(), tc_end, t);
auto t_end_i = (tc == tc_end || *tc >= t_end) ? t_end : *tc;
while (t < t_end) {
apply_zero_every(t, y, zero_every, internals);
Expand All @@ -254,7 +251,7 @@ class solver {
void initialise(const real_type t, const real_type* y,
ode::internals<real_type>& internals, Rhs rhs) {
internals.reset(y);
if (control_.debug_record_step_times) {
if (control.debug_record_step_times) {
internals.step_times.push_back(t);
}
auto f0 = internals.dydt.data();
Expand All @@ -270,14 +267,14 @@ class solver {
real_type norm_y = 0.0;

for (size_t i = 0; i < n_variables_; ++i) {
const real_type sk = control_.atol + control_.rtol * std::abs(y[i]);
const real_type sk = control.atol + control.rtol * std::abs(y[i]);
norm_f += square(f0[i] / sk);
norm_y += square(y[i] / sk);
}
// Magic numbers here, from Hairer
real_type h = (norm_f <= 1e-10 || norm_y <= 1e-10) ?
1e-6 : std::sqrt(norm_y / norm_f) * 0.01;
h = std::min(h, control_.step_size_max);
h = std::min(h, control.step_size_max);

// Perform an explicit Euler step
for (size_t i = 0; i < n_variables_; ++i) {
Expand All @@ -288,7 +285,7 @@ class solver {
// Estimate the second derivative of the solution:
real_type der2 = 0.0;
for (size_t i = 0; i < n_variables_; ++i) {
const real_type sk = control_.atol + control_.rtol * std::abs(y[i]);
const real_type sk = control.atol + control.rtol * std::abs(y[i]);
der2 += square((f1[i] - f0[i]) / sk);
}
der2 = std::sqrt(der2) / h;
Expand All @@ -304,8 +301,8 @@ class solver {
if (!std::isfinite(h)) {
throw std::runtime_error("Initial step size was not finite");
}
h = std::max(h, control_.step_size_min * 100);
h = std::min(h, control_.step_size_max);
h = std::max(h, control.step_size_min * 100);
h = std::min(h, control.step_size_max);
internals.step_size = h;
}

Expand Down Expand Up @@ -427,8 +424,8 @@ class solver {
}
}

private:
size_t n_variables_;
ode::control<real_type> control_;
std::vector<real_type> y_next_;
std::vector<real_type> y_stiff_;
std::vector<real_type> k2_;
Expand Down
2 changes: 1 addition & 1 deletion inst/include/dust2/continuous/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ class dust_continuous {
for (size_t i = 0, k = 0; i < n_threads_; ++i) {
for (size_t group = 0; group < n_groups_; ++group, ++k) {
delay_result_.push_back(delays_[group].result());
solver_[k].control().step_size_max =
solver_[k].control.step_size_max =
delays_[group].step_size_max(control_.step_size_max,
rhs_uses_delays_);
}
Expand Down

0 comments on commit 91b04a6

Please sign in to comment.