Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple event/root finding support #142

Merged
merged 19 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .covrignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ src/malaria.cpp
src/sir.cpp
src/sirode.cpp
src/walk.cpp
inst/include/lostturnip.hpp
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ src/sir.cpp linguist-generated=true
src/sirode.cpp linguist-generated=true
src/walk.cpp linguist-generated=true
R/import-*.R linguist-generated=true
inst/include/lostturnip.hpp linguist-vendored=true linguist-generated=true
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: dust2
Title: Next Generation dust
Version: 0.3.10
Version: 0.3.11
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Imperial College of Science, Technology and Medicine",
Expand Down
5 changes: 4 additions & 1 deletion R/interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,10 @@ dust_system_internals <- function(sys, include_coefficients = FALSE,
error = vnapply(dat, "[[", "error"),
n_steps = viapply(dat, "[[", "n_steps"),
n_steps_accepted = viapply(dat, "[[", "n_steps_accepted"),
n_steps_rejected = viapply(dat, "[[", "n_steps_rejected"))
n_steps_rejected = viapply(dat, "[[", "n_steps_rejected"),
events = I(lapply(dat, function(x) {
if (is.null(x$events)) NULL else as.data.frame(x$events)
})))
if (include_coefficients) {
ret$coefficients <- I(lapply(dat, "[[", "coefficients"))
}
Expand Down
1 change: 1 addition & 0 deletions inst/include/dust2/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <dust2/tools.hpp>
#include <dust2/zero.hpp>
#include <dust2/continuous/delays.hpp>
#include <dust2/continuous/events.hpp>
#include <cpp11/list.hpp>

// In an odd place, we might update that later too
Expand Down
63 changes: 63 additions & 0 deletions inst/include/dust2/continuous/events.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#pragma once

#include <functional>
#include <vector>

namespace dust2 {
namespace ode {

// Do we detect an event when passing through the root as we increase
// (negative to positive), decrease (positive to negative) or both:
enum class root_type {
both,
increase,
decrease
};

// The actual logic for the above
template <typename real_type>
bool is_root(const real_type a, const real_type b, const root_type& root) {
switch(root) {
case root_type::both:
return a * b < 0;
case root_type::increase:
return a < 0 && b > 0;
case root_type::decrease:
return a > 0 && b < 0;
}
return false;
}

template <typename real_type>
struct event {
using test_type = std::function<real_type(const real_type, const real_type*)>;
using action_type = std::function<void(const real_type, const real_type, real_type*)>;
std::vector<size_t> index;
root_type root = root_type::both;
test_type test;
action_type action;

event(const std::vector<size_t>& index, test_type test, action_type action, root_type root = root_type::both) :
index(index), root(root), test(test), action(action) {
}

event(size_t index, action_type action, root_type root = root_type::both) :
event({index}, [](real_type t, const real_type* y) { return y[0]; }, action, root) {
}
};

template <typename real_type>
using events_type = std::vector<event<real_type>>;

template <typename real_type>
struct event_history_element {
real_type time;
size_t index;
real_type sign;
};

template <typename real_type>
using event_history = std::vector<event_history_element<real_type>>;

}
}
65 changes: 60 additions & 5 deletions inst/include/dust2/continuous/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
#include <cmath>
#include <stdexcept>
#include <vector>
#include <lostturnip.hpp>

#include <dust2/tools.hpp>
#include <dust2/zero.hpp>
#include <dust2/continuous/control.hpp>
#include <dust2/continuous/events.hpp>
#include <dust2/continuous/history.hpp>

namespace dust2 {
Expand All @@ -31,6 +34,7 @@ template <typename real_type>
struct internals {
history_step<real_type> last;
history<real_type> history_values;
event_history<real_type> events;

std::vector<real_type> dydt;
std::vector<real_type> step_times;
Expand Down Expand Up @@ -147,10 +151,12 @@ class solver {
// Take a single step
template <typename Rhs>
real_type step(real_type t, real_type t_end, real_type* y,
const events_type<real_type>& events,
ode::internals<real_type>& internals, Rhs rhs) {
auto success = false;
auto reject = false;
auto truncated = false;
auto event = false;
auto h = internals.step_size;

while (!success) {
Expand Down Expand Up @@ -178,13 +184,22 @@ class solver {
if (err <= 1) {
success = true;
update_interpolation(t, h, y, internals);
if (!events.empty()) {
const auto t_next = apply_events(t, h, y, events, internals);
if (t_next < t + h) {
event = true;
truncated = false;
h = t_next - t;
rhs(t_next, y_next_.data(), k2_.data());
}
}
accept(t, h, y, internals);
internals.n_steps_accepted++;
if (control_.debug_record_step_times) {
internals.step_times.push_back(truncated ? t_end : t + h);
}
internals.save_history();
if (!truncated) {
if (!truncated && !event) {
const auto fac_old =
std::max(internals.error, static_cast<real_type>(1e-4));
auto fac = fac11 / std::pow(fac_old, control_.beta);
Expand All @@ -209,11 +224,12 @@ class solver {
template <typename Rhs>
void run(real_type t, real_type t_end, real_type* y,
zero_every_type<real_type>& zero_every,
const events_type<real_type>& events,
ode::internals<real_type>& internals, Rhs rhs) {
if (control_.critical_times.empty()) {
while (t < t_end) {
apply_zero_every(t, y, zero_every, internals);
t = step(t, t_end, y, internals, rhs);
t = step(t, t_end, y, events, internals, rhs);
}
} else {
// Slightly more complex loop which ensures we never integrate
Expand All @@ -224,7 +240,7 @@ class solver {
auto t_end_i = (tc == tc_end || *tc >= t_end) ? t_end : *tc;
while (t < t_end) {
apply_zero_every(t, y, zero_every, internals);
t = step(t, t_end_i, y, internals, rhs);
t = step(t, t_end_i, y, events, internals, rhs);
if (t >= t_end_i && t < t_end) {
++tc;
t_end_i = (tc == tc_end || *tc >= t_end) ? t_end : *tc;
Expand Down Expand Up @@ -295,8 +311,6 @@ class solver {

private:
void update_interpolation(real_type t, real_type h, real_type* y, ode::internals<real_type>& internals) {
// We might want to only do this bit if we'll actually use the
// history, but it's pretty cheap really.
internals.last.t0 = t;
internals.last.t1 = t + h;
internals.last.h = h;
Expand All @@ -315,6 +329,47 @@ class solver {
std::copy_n(y_next_.begin(), n_variables_, y);
}

real_type apply_events(real_type t0, real_type h, const real_type* y,
const events_type<real_type>& events,
ode::internals<real_type>& internals) {
size_t idx_first = events.size();
real_type t1 = t0 + h;
real_type sign = 0;

for (size_t idx_event = 0; idx_event < events.size(); ++idx_event) {
const auto& e = events[idx_event];
// Use y_stiff as temporary space here, it's only used
// transiently and within the step
real_type * y_t = y_stiff_.data();
auto fn = [&](auto t) {
internals.last.interpolate(t, e.index, y_t);
return e.test(t, y_t);
};
const auto f_t0 = fn(t0);
const auto f_t1 = fn(t1);
if (is_root(f_t0, f_t1, e.root)) {
// These probably should move into the ode control, but there
// should really be any great need to change them, and the
// interpolation is expected to be quite fast and accurate.
constexpr real_type eps = 1e-6;
constexpr size_t steps = 100;
auto root = lostturnip::find_result<real_type>(fn, t0, t1, eps, steps);
idx_first = idx_event;
t1 = root.x;
sign = f_t0 < 0 ? 1 : -1;
}
if (idx_first < events.size()) {
internals.last.interpolate(t1, y_next_.data());
events[idx_first].action(t1, sign, y_next_.data());
// We need to modify the history here so that search will find
// the right point.
internals.last.t1 = t1;
internals.events.push_back({t1, idx_first, sign});
}
}
return t1;
}

void apply_zero_every(real_type t, real_type* y,
const zero_every_type<real_type>& zero_every,
ode::internals<real_type>& internals) {
Expand Down
7 changes: 5 additions & 2 deletions inst/include/dust2/continuous/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vector>
#include <dust2/continuous/control.hpp>
#include <dust2/continuous/delays.hpp>
#include <dust2/continuous/events.hpp>
#include <dust2/continuous/solver.hpp>
#include <dust2/errors.hpp>
#include <dust2/internals.hpp>
Expand Down Expand Up @@ -67,6 +68,7 @@ class dust_continuous {
errors_(n_particles_total_),
rng_(n_particles_total_, seed, deterministic),
delays_(do_delays<T>(shared_)),
events_(do_events<T>(shared_, internal_)),
solver_(n_groups_ * n_threads_, {n_state_ode_, control_}),
output_is_current_(n_groups_),
requires_initialise_(n_groups_, true) {
Expand All @@ -89,7 +91,7 @@ class dust_continuous {
const auto offset = k * n_state_;
real_type * y = state_data + offset;
try {
solver_[i].run(time_, time, y, zero_every_[group],
solver_[i].run(time_, time, y, zero_every_[group], events_[i],
ode_internals_[k],
rhs_(particle, group, thread));
} catch (std::exception const& e) {
Expand Down Expand Up @@ -130,7 +132,7 @@ class dust_continuous {
for (size_t step = 0; step < n_steps; ++step) {
const real_type t0 = t1;
t1 = (step == n_steps - 1) ? time : time_ + step * dt_;
solver_[i].run(t0, t1, y, zero_every_[group],
solver_[i].run(t0, t1, y, zero_every_[group], events_[i],
ode_internals_[k],
rhs_(particle, group, thread));
std::copy_n(y, n_state_ode_, y_other);
Expand Down Expand Up @@ -389,6 +391,7 @@ class dust_continuous {
dust2::utils::errors errors_;
monty::random::prng<rng_state_type> rng_;
std::vector<ode::delays<real_type>> delays_;
std::vector<ode::events_type<real_type>> events_;
std::vector<ode::solver<real_type>> solver_;
std::vector<bool> output_is_current_;
std::vector<bool> requires_initialise_;
Expand Down
35 changes: 35 additions & 0 deletions inst/include/dust2/properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <dust2/packing.hpp>
#include <dust2/continuous/delays.hpp>
#include <dust2/continuous/events.hpp>
#include <dust2/zero.hpp>
#include <vector>

Expand Down Expand Up @@ -44,6 +45,11 @@ struct test_has_delays: std::false_type {};
template <class T>
struct test_has_delays<T, std::void_t<decltype(T::delays)>>: std::true_type {};

template <class T, class = void>
struct test_has_events: std::false_type {};
template <class T>
struct test_has_events<T, std::void_t<decltype(T::events)>>: std::true_type {};

// These test that the signature of rhs and output consume the delays
// argument. Not especially lovely to read!
template <typename T>
Expand Down Expand Up @@ -90,6 +96,7 @@ struct properties {
// Because of the above these are now actual numbers rather than types; we may make this change everywhere...
static constexpr bool rhs_uses_delays = internals::test_rhs_uses_delays<T>();
static constexpr bool output_uses_delays = internals::test_output_uses_delays<T>();
using has_events = internals::test_has_events<T>;
};

// wrappers around some uses of member functions that may or may not
Expand Down Expand Up @@ -161,4 +168,32 @@ auto do_delays(const std::vector<typename T::shared_state>& shared) {
return std::vector<ode::delays<real_type>>(shared.size(), dust2::ode::delays<real_type>{{}});
}

template <typename T, typename std::enable_if<properties<T>::has_events::value, T>::type* = nullptr>
auto do_events(const std::vector<typename T::shared_state>& shared,
std::vector<typename T::internal_state>& internal) {
using real_type = typename T::real_type;
std::vector<ode::events_type<real_type>> ret;

const auto n_groups = shared.size();
const auto n_threads = internal.size();
ret.reserve(n_threads * n_groups);
auto iter = internal.begin();
for (size_t i = 0; i < n_threads; ++i) {
for (auto& s : shared) {
ret.push_back(T::events(s, *iter));
++iter;
}
}
return ret;
}

template <typename T, typename std::enable_if<!properties<T>::has_events::value, T>::type* = nullptr>
auto do_events(const std::vector<typename T::shared_state>& shared,
std::vector<typename T::internal_state>& internal) {
using real_type = typename T::real_type;
const auto len = shared.size() * internal.size();
const auto empty = dust2::ode::events_type<real_type>{{}};
return std::vector<ode::events_type<real_type>>(len, empty);
}

}
19 changes: 18 additions & 1 deletion inst/include/dust2/r/continuous/system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals<real_type>& internals,
"n_steps_accepted"_nm = cpp11::as_sexp(internals.n_steps_accepted),
"n_steps_rejected"_nm = cpp11::as_sexp(internals.n_steps_rejected),
"coefficients"_nm = R_NilValue,
"history"_nm = R_NilValue};
"history"_nm = R_NilValue,
"events"_nm = R_NilValue};

if (include_coefficients) {
auto r_coef = cpp11::writable::doubles_matrix<>(internals.last.c1.size(), 5);
auto coef = REAL(r_coef);
Expand Down Expand Up @@ -107,6 +109,21 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals<real_type>& internals,
"coefficients"_nm = std::move(r_history_coef)};
ret["history"] = cpp11::as_sexp(r_history);
}
if (!internals.events.empty()) {
const auto n_events = internals.events.size();
auto r_event_time = cpp11::writable::doubles(n_events);
auto r_event_index = cpp11::writable::integers(n_events);
auto r_event_sign = cpp11::writable::doubles(n_events);
for (size_t i = 0; i < n_events; ++i) {
r_event_time[i] = internals.events[i].time;
r_event_index[i] = static_cast<int>(internals.events[i].index);
richfitz marked this conversation as resolved.
Show resolved Hide resolved
r_event_sign[i] = internals.events[i].sign;
}
auto r_events = cpp11::writable::list{"time"_nm = std::move(r_event_time),
"index"_nm = std::move(r_event_index),
"sign"_nm = std::move(r_event_sign)};
ret["events"] = cpp11::as_sexp(r_events);
}
return ret;
}

Expand Down
Loading
Loading