Skip to content

Commit

Permalink
Merge pull request #142 from mrc-ide/mrc-5454
Browse files Browse the repository at this point in the history
Simple event/root finding support
  • Loading branch information
weshinsley authored Nov 28, 2024
2 parents bdf52f4 + da2af0c commit e3e1b63
Show file tree
Hide file tree
Showing 14 changed files with 465 additions and 10 deletions.
1 change: 1 addition & 0 deletions .covrignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ src/malaria.cpp
src/sir.cpp
src/sirode.cpp
src/walk.cpp
inst/include/lostturnip.hpp
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ src/sir.cpp linguist-generated=true
src/sirode.cpp linguist-generated=true
src/walk.cpp linguist-generated=true
R/import-*.R linguist-generated=true
inst/include/lostturnip.hpp linguist-vendored=true linguist-generated=true
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: dust2
Title: Next Generation dust
Version: 0.3.10
Version: 0.3.11
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Imperial College of Science, Technology and Medicine",
Expand Down
5 changes: 4 additions & 1 deletion R/interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,10 @@ dust_system_internals <- function(sys, include_coefficients = FALSE,
error = vnapply(dat, "[[", "error"),
n_steps = viapply(dat, "[[", "n_steps"),
n_steps_accepted = viapply(dat, "[[", "n_steps_accepted"),
n_steps_rejected = viapply(dat, "[[", "n_steps_rejected"))
n_steps_rejected = viapply(dat, "[[", "n_steps_rejected"),
events = I(lapply(dat, function(x) {
if (is.null(x$events)) NULL else as.data.frame(x$events)
})))
if (include_coefficients) {
ret$coefficients <- I(lapply(dat, "[[", "coefficients"))
}
Expand Down
1 change: 1 addition & 0 deletions inst/include/dust2/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <dust2/tools.hpp>
#include <dust2/zero.hpp>
#include <dust2/continuous/delays.hpp>
#include <dust2/continuous/events.hpp>
#include <cpp11/list.hpp>

// In an odd place, we might update that later too
Expand Down
63 changes: 63 additions & 0 deletions inst/include/dust2/continuous/events.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#pragma once

#include <functional>
#include <vector>

namespace dust2 {
namespace ode {

// Do we detect an event when passing through the root as we increase
// (negative to positive), decrease (positive to negative) or both:
enum class root_type {
both,
increase,
decrease
};

// The actual logic for the above
template <typename real_type>
bool is_root(const real_type a, const real_type b, const root_type& root) {
switch(root) {
case root_type::both:
return a * b < 0;
case root_type::increase:
return a < 0 && b > 0;
case root_type::decrease:
return a > 0 && b < 0;
}
return false;
}

template <typename real_type>
struct event {
using test_type = std::function<real_type(const real_type, const real_type*)>;
using action_type = std::function<void(const real_type, const real_type, real_type*)>;
std::vector<size_t> index;
root_type root = root_type::both;
test_type test;
action_type action;

event(const std::vector<size_t>& index, test_type test, action_type action, root_type root = root_type::both) :
index(index), root(root), test(test), action(action) {
}

event(size_t index, action_type action, root_type root = root_type::both) :
event({index}, [](real_type t, const real_type* y) { return y[0]; }, action, root) {
}
};

template <typename real_type>
using events_type = std::vector<event<real_type>>;

template <typename real_type>
struct event_history_element {
real_type time;
size_t index;
real_type sign;
};

template <typename real_type>
using event_history = std::vector<event_history_element<real_type>>;

}
}
65 changes: 60 additions & 5 deletions inst/include/dust2/continuous/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
#include <cmath>
#include <stdexcept>
#include <vector>
#include <lostturnip.hpp>

#include <dust2/tools.hpp>
#include <dust2/zero.hpp>
#include <dust2/continuous/control.hpp>
#include <dust2/continuous/events.hpp>
#include <dust2/continuous/history.hpp>

namespace dust2 {
Expand All @@ -31,6 +34,7 @@ template <typename real_type>
struct internals {
history_step<real_type> last;
history<real_type> history_values;
event_history<real_type> events;

std::vector<real_type> dydt;
std::vector<real_type> step_times;
Expand Down Expand Up @@ -147,10 +151,12 @@ class solver {
// Take a single step
template <typename Rhs>
real_type step(real_type t, real_type t_end, real_type* y,
const events_type<real_type>& events,
ode::internals<real_type>& internals, Rhs rhs) {
auto success = false;
auto reject = false;
auto truncated = false;
auto event = false;
auto h = internals.step_size;

while (!success) {
Expand Down Expand Up @@ -178,13 +184,22 @@ class solver {
if (err <= 1) {
success = true;
update_interpolation(t, h, y, internals);
if (!events.empty()) {
const auto t_next = apply_events(t, h, y, events, internals);
if (t_next < t + h) {
event = true;
truncated = false;
h = t_next - t;
rhs(t_next, y_next_.data(), k2_.data());
}
}
accept(t, h, y, internals);
internals.n_steps_accepted++;
if (control_.debug_record_step_times) {
internals.step_times.push_back(truncated ? t_end : t + h);
}
internals.save_history();
if (!truncated) {
if (!truncated && !event) {
const auto fac_old =
std::max(internals.error, static_cast<real_type>(1e-4));
auto fac = fac11 / std::pow(fac_old, control_.beta);
Expand All @@ -209,11 +224,12 @@ class solver {
template <typename Rhs>
void run(real_type t, real_type t_end, real_type* y,
zero_every_type<real_type>& zero_every,
const events_type<real_type>& events,
ode::internals<real_type>& internals, Rhs rhs) {
if (control_.critical_times.empty()) {
while (t < t_end) {
apply_zero_every(t, y, zero_every, internals);
t = step(t, t_end, y, internals, rhs);
t = step(t, t_end, y, events, internals, rhs);
}
} else {
// Slightly more complex loop which ensures we never integrate
Expand All @@ -224,7 +240,7 @@ class solver {
auto t_end_i = (tc == tc_end || *tc >= t_end) ? t_end : *tc;
while (t < t_end) {
apply_zero_every(t, y, zero_every, internals);
t = step(t, t_end_i, y, internals, rhs);
t = step(t, t_end_i, y, events, internals, rhs);
if (t >= t_end_i && t < t_end) {
++tc;
t_end_i = (tc == tc_end || *tc >= t_end) ? t_end : *tc;
Expand Down Expand Up @@ -295,8 +311,6 @@ class solver {

private:
void update_interpolation(real_type t, real_type h, real_type* y, ode::internals<real_type>& internals) {
// We might want to only do this bit if we'll actually use the
// history, but it's pretty cheap really.
internals.last.t0 = t;
internals.last.t1 = t + h;
internals.last.h = h;
Expand All @@ -315,6 +329,47 @@ class solver {
std::copy_n(y_next_.begin(), n_variables_, y);
}

real_type apply_events(real_type t0, real_type h, const real_type* y,
const events_type<real_type>& events,
ode::internals<real_type>& internals) {
size_t idx_first = events.size();
real_type t1 = t0 + h;
real_type sign = 0;

for (size_t idx_event = 0; idx_event < events.size(); ++idx_event) {
const auto& e = events[idx_event];
// Use y_stiff as temporary space here, it's only used
// transiently and within the step
real_type * y_t = y_stiff_.data();
auto fn = [&](auto t) {
internals.last.interpolate(t, e.index, y_t);
return e.test(t, y_t);
};
const auto f_t0 = fn(t0);
const auto f_t1 = fn(t1);
if (is_root(f_t0, f_t1, e.root)) {
// These probably should move into the ode control, but there
// should really be any great need to change them, and the
// interpolation is expected to be quite fast and accurate.
constexpr real_type eps = 1e-6;
constexpr size_t steps = 100;
auto root = lostturnip::find_result<real_type>(fn, t0, t1, eps, steps);
idx_first = idx_event;
t1 = root.x;
sign = f_t0 < 0 ? 1 : -1;
}
if (idx_first < events.size()) {
internals.last.interpolate(t1, y_next_.data());
events[idx_first].action(t1, sign, y_next_.data());
// We need to modify the history here so that search will find
// the right point.
internals.last.t1 = t1;
internals.events.push_back({t1, idx_first, sign});
}
}
return t1;
}

void apply_zero_every(real_type t, real_type* y,
const zero_every_type<real_type>& zero_every,
ode::internals<real_type>& internals) {
Expand Down
7 changes: 5 additions & 2 deletions inst/include/dust2/continuous/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vector>
#include <dust2/continuous/control.hpp>
#include <dust2/continuous/delays.hpp>
#include <dust2/continuous/events.hpp>
#include <dust2/continuous/solver.hpp>
#include <dust2/errors.hpp>
#include <dust2/internals.hpp>
Expand Down Expand Up @@ -67,6 +68,7 @@ class dust_continuous {
errors_(n_particles_total_),
rng_(n_particles_total_, seed, deterministic),
delays_(do_delays<T>(shared_)),
events_(do_events<T>(shared_, internal_)),
solver_(n_groups_ * n_threads_, {n_state_ode_, control_}),
output_is_current_(n_groups_),
requires_initialise_(n_groups_, true) {
Expand All @@ -89,7 +91,7 @@ class dust_continuous {
const auto offset = k * n_state_;
real_type * y = state_data + offset;
try {
solver_[i].run(time_, time, y, zero_every_[group],
solver_[i].run(time_, time, y, zero_every_[group], events_[i],
ode_internals_[k],
rhs_(particle, group, thread));
} catch (std::exception const& e) {
Expand Down Expand Up @@ -130,7 +132,7 @@ class dust_continuous {
for (size_t step = 0; step < n_steps; ++step) {
const real_type t0 = t1;
t1 = (step == n_steps - 1) ? time : time_ + step * dt_;
solver_[i].run(t0, t1, y, zero_every_[group],
solver_[i].run(t0, t1, y, zero_every_[group], events_[i],
ode_internals_[k],
rhs_(particle, group, thread));
std::copy_n(y, n_state_ode_, y_other);
Expand Down Expand Up @@ -389,6 +391,7 @@ class dust_continuous {
dust2::utils::errors errors_;
monty::random::prng<rng_state_type> rng_;
std::vector<ode::delays<real_type>> delays_;
std::vector<ode::events_type<real_type>> events_;
std::vector<ode::solver<real_type>> solver_;
std::vector<bool> output_is_current_;
std::vector<bool> requires_initialise_;
Expand Down
35 changes: 35 additions & 0 deletions inst/include/dust2/properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <dust2/packing.hpp>
#include <dust2/continuous/delays.hpp>
#include <dust2/continuous/events.hpp>
#include <dust2/zero.hpp>
#include <vector>

Expand Down Expand Up @@ -44,6 +45,11 @@ struct test_has_delays: std::false_type {};
template <class T>
struct test_has_delays<T, std::void_t<decltype(T::delays)>>: std::true_type {};

template <class T, class = void>
struct test_has_events: std::false_type {};
template <class T>
struct test_has_events<T, std::void_t<decltype(T::events)>>: std::true_type {};

// These test that the signature of rhs and output consume the delays
// argument. Not especially lovely to read!
template <typename T>
Expand Down Expand Up @@ -90,6 +96,7 @@ struct properties {
// Because of the above these are now actual numbers rather than types; we may make this change everywhere...
static constexpr bool rhs_uses_delays = internals::test_rhs_uses_delays<T>();
static constexpr bool output_uses_delays = internals::test_output_uses_delays<T>();
using has_events = internals::test_has_events<T>;
};

// wrappers around some uses of member functions that may or may not
Expand Down Expand Up @@ -161,4 +168,32 @@ auto do_delays(const std::vector<typename T::shared_state>& shared) {
return std::vector<ode::delays<real_type>>(shared.size(), dust2::ode::delays<real_type>{{}});
}

template <typename T, typename std::enable_if<properties<T>::has_events::value, T>::type* = nullptr>
auto do_events(const std::vector<typename T::shared_state>& shared,
std::vector<typename T::internal_state>& internal) {
using real_type = typename T::real_type;
std::vector<ode::events_type<real_type>> ret;

const auto n_groups = shared.size();
const auto n_threads = internal.size();
ret.reserve(n_threads * n_groups);
auto iter = internal.begin();
for (size_t i = 0; i < n_threads; ++i) {
for (auto& s : shared) {
ret.push_back(T::events(s, *iter));
++iter;
}
}
return ret;
}

template <typename T, typename std::enable_if<!properties<T>::has_events::value, T>::type* = nullptr>
auto do_events(const std::vector<typename T::shared_state>& shared,
std::vector<typename T::internal_state>& internal) {
using real_type = typename T::real_type;
const auto len = shared.size() * internal.size();
const auto empty = dust2::ode::events_type<real_type>{{}};
return std::vector<ode::events_type<real_type>>(len, empty);
}

}
19 changes: 18 additions & 1 deletion inst/include/dust2/r/continuous/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals<real_type>& internals,
"n_steps_accepted"_nm = cpp11::as_sexp(internals.n_steps_accepted),
"n_steps_rejected"_nm = cpp11::as_sexp(internals.n_steps_rejected),
"coefficients"_nm = R_NilValue,
"history"_nm = R_NilValue};
"history"_nm = R_NilValue,
"events"_nm = R_NilValue};

if (include_coefficients) {
auto r_coef = cpp11::writable::doubles_matrix<>(internals.last.c1.size(), 5);
auto coef = REAL(r_coef);
Expand Down Expand Up @@ -107,6 +109,21 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals<real_type>& internals,
"coefficients"_nm = std::move(r_history_coef)};
ret["history"] = cpp11::as_sexp(r_history);
}
if (!internals.events.empty()) {
const auto n_events = internals.events.size();
auto r_event_time = cpp11::writable::doubles(n_events);
auto r_event_index = cpp11::writable::integers(n_events);
auto r_event_sign = cpp11::writable::doubles(n_events);
for (size_t i = 0; i < n_events; ++i) {
r_event_time[i] = internals.events[i].time;
r_event_index[i] = static_cast<int>(internals.events[i].index) + 1;
r_event_sign[i] = internals.events[i].sign;
}
auto r_events = cpp11::writable::list{"time"_nm = std::move(r_event_time),
"index"_nm = std::move(r_event_index),
"sign"_nm = std::move(r_event_sign)};
ret["events"] = cpp11::as_sexp(r_events);
}
return ret;
}

Expand Down
Loading

0 comments on commit e3e1b63

Please sign in to comment.