From 3b0a6abb23f871070c23ba1751fc524086b519cb Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 25 Nov 2024 08:16:39 +0000 Subject: [PATCH 01/19] Basic support for events --- inst/include/dust2/common.hpp | 1 + inst/include/dust2/continuous/events.hpp | 43 ++++++ inst/include/dust2/continuous/history.hpp | 9 ++ inst/include/dust2/continuous/solver.hpp | 59 +++++++- inst/include/dust2/continuous/system.hpp | 7 +- inst/include/dust2/properties.hpp | 24 +++ inst/include/lostturnip.hpp | 169 ++++++++++++++++++++++ tests/testthat/examples/event.cpp | 68 +++++++++ 8 files changed, 374 insertions(+), 6 deletions(-) create mode 100644 inst/include/dust2/continuous/events.hpp create mode 100644 inst/include/lostturnip.hpp create mode 100644 tests/testthat/examples/event.cpp diff --git a/inst/include/dust2/common.hpp b/inst/include/dust2/common.hpp index 04c87007..18cb1b41 100644 --- a/inst/include/dust2/common.hpp +++ b/inst/include/dust2/common.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include // In an odd place, we might update that later too diff --git a/inst/include/dust2/continuous/events.hpp b/inst/include/dust2/continuous/events.hpp new file mode 100644 index 00000000..8bafc29e --- /dev/null +++ b/inst/include/dust2/continuous/events.hpp @@ -0,0 +1,43 @@ +#pragma once + +#include +#include + +namespace dust2 { +namespace ode { + +// Do we detect an event when passing through the root as we increase +// (negative to positive), decrease (positive to negative) or both: +enum class root_type { + both, + increase, + decrease +}; + +// The actual logic for the above +template +bool is_root(const real_type a, const real_type b, const root_type& root) { + switch(root) { + case root_type::both: + return a * b < 0; + case root_type::increase: + return a < 0 && b > 0; + case root_type::decrease: + return a > 0 && b < 0; + } + return false; +} + +template +struct event { + size_t index; + real_type value; + std::function action; // time, y + root_type root = root_type::both; +}; + +template +using events_type = std::vector>; + +} +} diff --git a/inst/include/dust2/continuous/history.hpp b/inst/include/dust2/continuous/history.hpp index de21f4e1..5402754e 100644 --- a/inst/include/dust2/continuous/history.hpp +++ b/inst/include/dust2/continuous/history.hpp @@ -66,6 +66,15 @@ struct history_step { } } + real_type interpolate(real_type time, size_t i) const { + // Consider special case for u or v == 0 + // u == 0: return c1[i] + // v == 0: return c1[i] + c2[i] + const auto u = (time - t0) / h; + const auto v = 1 - u; + return c1[i] + u * (c2[i] + v * (c3[i] + u * (c4[i] + v * c5[i]))); + } + history_step subset(std::vector index) const { return history_step(t0, t1, diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index e3397adb..118f2dae 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -4,6 +4,8 @@ #include #include #include +#include + #include #include #include @@ -147,6 +149,7 @@ class solver { // Take a single step template real_type step(real_type t, real_type t_end, real_type* y, + const events_type& events, ode::internals& internals, Rhs rhs) { auto success = false; auto reject = false; @@ -178,6 +181,16 @@ class solver { if (err <= 1) { success = true; update_interpolation(t, h, y, internals); + // If we end up using a std::array of these, we can make this + // constexpr, which is nice. + if (!events.empty()) { + const auto t_next = apply_events(t, h, y, events, internals); + if (t_next < t + h) { + truncated = true; + h = t_next - t; + rhs(t_next, y_next_.data(), k2_.data()); + } + } accept(t, h, y, internals); internals.n_steps_accepted++; if (control_.debug_record_step_times) { @@ -209,11 +222,12 @@ class solver { template void run(real_type t, real_type t_end, real_type* y, zero_every_type& zero_every, + const events_type& events, ode::internals& internals, Rhs rhs) { if (control_.critical_times.empty()) { while (t < t_end) { apply_zero_every(t, y, zero_every, internals); - t = step(t, t_end, y, internals, rhs); + t = step(t, t_end, y, events, internals, rhs); } } else { // Slightly more complex loop which ensures we never integrate @@ -224,7 +238,7 @@ class solver { auto t_end_i = (tc == tc_end || *tc >= t_end) ? t_end : *tc; while (t < t_end) { apply_zero_every(t, y, zero_every, internals); - t = step(t, t_end_i, y, internals, rhs); + t = step(t, t_end_i, y, events, internals, rhs); if (t >= t_end_i && t < t_end) { ++tc; t_end_i = (tc == tc_end || *tc >= t_end) ? t_end : *tc; @@ -295,8 +309,6 @@ class solver { private: void update_interpolation(real_type t, real_type h, real_type* y, ode::internals& internals) { - // We might want to only do this bit if we'll actually use the - // history, but it's pretty cheap really. internals.last.t0 = t; internals.last.t1 = t + h; internals.last.h = h; @@ -315,6 +327,45 @@ class solver { std::copy_n(y_next_.begin(), n_variables_, y); } + real_type apply_events(real_type t0, real_type h, const real_type* y, + const events_type& events, + ode::internals& internals) { + bool found = false; + size_t idx_first = 0; + real_type t1 = t0 + h; + + for (size_t idx_event = 0; idx_event < events.size(); ++idx_event) { + const auto& e = events[idx_event]; + const auto idx_state = e.index; + const auto value = e.value; + auto fn = [&](auto t) { + return internals.last.interpolate(t, idx_state) - value; + }; + if (is_root(fn(t0), fn(t1), e.root)) { + // These probably should move into the ode control, but there + // should really be any great need to change them, and the + // interpolation is expected to be quite fast and accurate. + constexpr real_type eps = 1e-6; + constexpr size_t steps = 100; + auto root = lostturnip::find_result(fn, t0, t1, eps, steps); + found = true; + idx_first = idx_event; + t1 = root.x; + } + if (found) { + internals.last.interpolate(t1, y_next_.data()); + // These actions probably will have needed to bind + // shared/internal eventually, that will be done elsewhere. + events[idx_first].action(t1, y_next_.data()); + // We need to modify the history here so that search will find + // the right point. + internals.last.t1 = t1; + // TODO: log event! + } + } + return t1; + } + void apply_zero_every(real_type t, real_type* y, const zero_every_type& zero_every, ode::internals& internals) { diff --git a/inst/include/dust2/continuous/system.hpp b/inst/include/dust2/continuous/system.hpp index c9021fd3..4ead6719 100644 --- a/inst/include/dust2/continuous/system.hpp +++ b/inst/include/dust2/continuous/system.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -67,6 +68,7 @@ class dust_continuous { errors_(n_particles_total_), rng_(n_particles_total_, seed, deterministic), delays_(do_delays(shared_)), + events_(do_events(shared_)), solver_(n_groups_ * n_threads_, {n_state_ode_, control_}), output_is_current_(n_groups_), requires_initialise_(n_groups_, true) { @@ -89,7 +91,7 @@ class dust_continuous { const auto offset = k * n_state_; real_type * y = state_data + offset; try { - solver_[i].run(time_, time, y, zero_every_[group], + solver_[i].run(time_, time, y, zero_every_[group], events_[group], ode_internals_[k], rhs_(particle, group, thread)); } catch (std::exception const& e) { @@ -130,7 +132,7 @@ class dust_continuous { for (size_t step = 0; step < n_steps; ++step) { const real_type t0 = t1; t1 = (step == n_steps - 1) ? time : time_ + step * dt_; - solver_[i].run(t0, t1, y, zero_every_[group], + solver_[i].run(t0, t1, y, zero_every_[group], events_[group], ode_internals_[k], rhs_(particle, group, thread)); std::copy_n(y, n_state_ode_, y_other); @@ -389,6 +391,7 @@ class dust_continuous { dust2::utils::errors errors_; monty::random::prng rng_; std::vector> delays_; + std::vector> events_; std::vector> solver_; std::vector output_is_current_; std::vector requires_initialise_; diff --git a/inst/include/dust2/properties.hpp b/inst/include/dust2/properties.hpp index 61eb37f9..e356718d 100644 --- a/inst/include/dust2/properties.hpp +++ b/inst/include/dust2/properties.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -44,6 +45,11 @@ struct test_has_delays: std::false_type {}; template struct test_has_delays>: std::true_type {}; +template +struct test_has_events: std::false_type {}; +template +struct test_has_events>: std::true_type {}; + // These test that the signature of rhs and output consume the delays // argument. Not especially lovely to read! template @@ -90,6 +96,7 @@ struct properties { // Because of the above these are now actual numbers rather than types; we may make this change everywhere... static constexpr bool rhs_uses_delays = internals::test_rhs_uses_delays(); static constexpr bool output_uses_delays = internals::test_output_uses_delays(); + using has_events = internals::test_has_events; }; // wrappers around some uses of member functions that may or may not @@ -161,4 +168,21 @@ auto do_delays(const std::vector& shared) { return std::vector>(shared.size(), dust2::ode::delays{{}}); } +template ::has_events::value, T>::type* = nullptr> +auto do_events(const std::vector& shared) { + using real_type = typename T::real_type; + std::vector> ret; + ret.reserve(shared.size()); + for (size_t i = 0; i < shared.size(); ++i) { + ret.push_back(T::events(shared[i])); + } + return ret; +} + +template ::has_events::value, T>::type* = nullptr> +auto do_events(const std::vector& shared) { + using real_type = typename T::real_type; + return std::vector>(shared.size(), dust2::ode::events_type{{}}); +} + } diff --git a/inst/include/lostturnip.hpp b/inst/include/lostturnip.hpp new file mode 100644 index 00000000..3e1ea5d8 --- /dev/null +++ b/inst/include/lostturnip.hpp @@ -0,0 +1,169 @@ +#pragma once +#include +#include +#include + +namespace lostturnip { + +// Declaring these here, rather than within the find_result, as +// otherwise we get a compiler warning about using experimental cuda +// features. It will be equivalent though, but does require C++14. +namespace { +template +constexpr real_type na = std::numeric_limits::quiet_NaN(); + +template +constexpr real_type eps = std::numeric_limits::epsilon(); +} + +template +struct result { + real_type x; + real_type fx; + int iterations; + bool converged; +}; + +// From zeroin.c, in brent.shar +template +#ifdef __NVCC__ +__host__ __device__ +#endif +result find_result(F f, real_type a, real_type b, + real_type tol, int max_iterations) { + real_type fa = f(a); + real_type fb = f(b); + int iterations = 0; + bool converged = false; + + if (fa == 0) { + b = a; + fb = fa; + converged = true; + } else if (fb == 0) { + converged = true; + } else if (fa * fb > 0) { + // Same sign; can't find root with this: + b = na; + fb = na; + converged = false; + } else { + real_type c = a; + real_type fc = fa; // c = a, f(c) = f(a) + + for (; iterations < max_iterations; ++iterations) { // Main iteration loop + // Distance from the last but one to the last approximation + const real_type prev_step = b - a; + + // Interpolation step is calculated in the form p/q; division + // operations is dlayed until the last moment + real_type p; + real_type q; + + if (std::abs(fc) < std::abs(fb)) { + // Swap data for b to be the best approximation + a = b; + b = c; + c = a; + fa = fb; + fb = fc; + fc = fa; + } + + // Actual tolerance + const real_type tol_act = 2 * eps * std::abs(b) + tol / 2; + // Step at this iteration + real_type new_step = (c - b) / 2; + + if (std::abs(new_step) <= tol_act || fb == 0) { + // Acceptable approximation is found + converged = true; + break; + } + + // increase readability below, avoids many repeated static casts + const real_type one = 1; + + // Decide if the interpolation can be tried + // + // If prev_step was large enough and was in true direction, then + // interpolation can be tried + if (std::abs(prev_step) >= tol_act && std::abs(fa) > std::abs(fb)) { + // interpolation + const real_type cb = c - b; + if (a == c) { + // If we have only two distinct points linear interpolation + // can only be applied + const real_type t1 = fb / fa; + p = cb * t1; + q = one - t1; + } else { + // Quadric inverse interpolation + q = fa / fc; + const real_type t1 = fb / fc; + const real_type t2 = fb / fa; + p = t2 * (cb * q * (q - t1) - (b - a) * (t1 - one)); + q = (q - one) * (t1 - one) * (t2 - one); + } + if (p > 0) { + // p was calculated with the opposite sign; make p positive + // and assign possible minus to q + q = -q; + } else { + p = -p; + } + + // If b + p / q falls in [b, c] and isn't too large it is + // accepted + // + // If p / q is too large then the bissection procedure can + // reduce [b,c] range to more extent + if (p < (static_cast(0.75) * cb * q - std::abs(tol_act * q) / 2) && + p < std::abs(prev_step * q / 2)) { + new_step = p / q; + } + } + + // Adjust the step to be not less than tolerance + if (std::abs(new_step) < tol_act) { + new_step = std::copysign(tol_act, new_step); + } + + // Save the previous approximation + a = b; + fa = fb; + // Do step to a new approximation + b += new_step; + fb = f(b); + if ((fb > 0 && fc > 0) || (fb < 0 && fc < 0)) { + // Adjust c for it to have a sign opposite to that of b + c = a; fc = fa; + } + } + } + +#ifdef __CUDA_ARCH__ + __syncwarp(); +#endif + return result{b, fb, iterations, converged}; +} + +template +#ifdef __NVCC__ +__host__ __device__ +#endif +real_type find(F f, real_type a, real_type b, + real_type tol, int max_iterations) { + const auto result = find_result(f, a, b, tol, max_iterations); + if (!result.converged) { +#ifdef __CUDA_ARCH__ + printf("some error\n"); + __trap(); +#else + throw std::runtime_error("some error"); +#endif + } + return result.x; +} + +} diff --git a/tests/testthat/examples/event.cpp b/tests/testthat/examples/event.cpp new file mode 100644 index 00000000..93e4eef5 --- /dev/null +++ b/tests/testthat/examples/event.cpp @@ -0,0 +1,68 @@ +#include + +// [[dust2::class(bounce)]] +// [[dust2::time_type(continuous)]] +// [[dust2::parameter(height, rank = 0)]] +// [[dust2::parameter(velocity, rank = 0)]] +class bounce { +public: + bounce() = delete; + + using real_type = double; + + struct shared_state { + real_type g; + real_type height; + real_type velocity; + real_type damp; + }; + + struct internal_state {}; + + using rng_state_type = monty::random::generator; + + static dust2::packing packing_state(const shared_state& shared) { + return dust2::packing{{"height", {}}, {"velocity", {}}}; + } + + static void initial(real_type time, + const shared_state& shared, + internal_state& internal, + rng_state_type& rng_state, + real_type * state) { + state[0] = shared.height; + state[1] = shared.velocity; + } + + static void rhs(real_type time, + const real_type * state, + const shared_state& shared, + internal_state& internal, + real_type * state_deriv) { + state_deriv[0] = state[1]; + state_deriv[1] = -shared.g; + } + + static shared_state build_shared(cpp11::list pars) { + const real_type g = 9.81; + const real_type height = dust2::r::read_real(pars, "height", 0); + const real_type velocity = dust2::r::read_real(pars, "velocity", 10); + const real_type damp = dust2::r::read_real(pars, "damp", 0.9); + return shared_state{g, height, velocity, damp}; + } + + static void update_shared(cpp11::list pars, shared_state& shared) { + shared.damp = dust2::r::read_real(pars, "damp", shared.damp); + } + + static auto events(const shared_state& shared) { + // We can capture 'shared' by scope here, but not internal; that + // would require a second phase of binding, which would need to be + // done by the system. + dust2::ode::event e{0, 0, [&](double t, double* y) { + y[0] = 0; + y[1] = -shared.damp * y[1]; + }}; + return dust2::ode::events_type({e}); + } +}; From 56f69deb4dbe0384d20e076ed9ba17a4176949f9 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 25 Nov 2024 08:40:05 +0000 Subject: [PATCH 02/19] Add basic test --- DESCRIPTION | 1 + tests/testthat/examples/event.cpp | 2 +- tests/testthat/helper-dust.R | 22 ++++++++++++++++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index b205b64a..6257b20b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -27,6 +27,7 @@ Suggests: callr, cpp11, decor, + deSolve, fs, glue, knitr, diff --git a/tests/testthat/examples/event.cpp b/tests/testthat/examples/event.cpp index 93e4eef5..c580b179 100644 --- a/tests/testthat/examples/event.cpp +++ b/tests/testthat/examples/event.cpp @@ -44,7 +44,7 @@ class bounce { } static shared_state build_shared(cpp11::list pars) { - const real_type g = 9.81; + const real_type g = 9.8; const real_type height = dust2::r::read_real(pars, "height", 0); const real_type velocity = dust2::r::read_real(pars, "velocity", 10); const real_type damp = dust2::r::read_real(pars, "damp", 0.9); diff --git a/tests/testthat/helper-dust.R b/tests/testthat/helper-dust.R index 28244ec0..7c740306 100644 --- a/tests/testthat/helper-dust.R +++ b/tests/testthat/helper-dust.R @@ -83,3 +83,25 @@ local_sir_generator <- function() { options(dust.testing.local_sir_generator = gen) gen } + + +example_bounce <- function(t) { + skip_if_not_installed("deSolve") + ball <- function(t, y, parms) { + dy1 <- y[2] + dy2 <- -9.8 + list(c(dy1, dy2)) + } + yini <- c(height = 0, velocity = 10) + rootfunc <- function(t, y, parms) { + y[1] + } + eventfunc <- function(t, y, parms) { + y[1] <- 0 + y[2] <- -0.9 * y[2] + y + } + deSolve::ode(times = t, y = yini, func = ball, + parms = NULL, rootfunc = rootfunc, + events = list(func = eventfunc, root = TRUE)) +} From 37bdcaa136da807bdd76af9ce05a04740fff48f3 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 25 Nov 2024 09:46:55 +0000 Subject: [PATCH 03/19] Make headers self-sufficient --- inst/include/dust2/continuous/solver.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index 118f2dae..68f5e912 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace dust2 { From 8c1ff03ff048afa4f9850c6863de7269279ae55b Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 25 Nov 2024 09:52:47 +0000 Subject: [PATCH 04/19] Include test file --- tests/testthat/test-zzz-events.R | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 tests/testthat/test-zzz-events.R diff --git a/tests/testthat/test-zzz-events.R b/tests/testthat/test-zzz-events.R new file mode 100644 index 00000000..75f5a59e --- /dev/null +++ b/tests/testthat/test-zzz-events.R @@ -0,0 +1,16 @@ +test_that("can run system with roots and events", { + gen <- dust_compile("examples/event.cpp", quiet = TRUE, debug = TRUE) + + sys <- dust_system_create(gen) + dust_system_set_state_initial(sys) + t <- seq(0, 6, length.out = 500) + y <- dust_system_simulate(sys, t) + + ## This is realy not great, but at this point I don't know who is + ## worse. Qualitatively we're about right though and I need to go + ## through and compare with an analytic solution as here we've got + ## two different approximations and a nonlinear system that is + ## acumulating error. + cmp <- example_bounce(t) + expect_equal(y[1, ], cmp[, 2], tolerance = 1e-2) +}) From c72ef2e90b15bf9ee1dded57290a556202b061dd Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 25 Nov 2024 10:12:09 +0000 Subject: [PATCH 05/19] Save details about events --- R/interface.R | 5 ++++- inst/include/dust2/continuous/events.hpp | 14 ++++++++++++-- inst/include/dust2/continuous/solver.hpp | 9 +++++++-- inst/include/dust2/r/continuous/system.hpp | 19 ++++++++++++++++++- tests/testthat/test-zzz-events.R | 3 +++ 5 files changed, 44 insertions(+), 6 deletions(-) diff --git a/R/interface.R b/R/interface.R index 9ca7160b..4f65338d 100644 --- a/R/interface.R +++ b/R/interface.R @@ -485,7 +485,10 @@ dust_system_internals <- function(sys, include_coefficients = FALSE, error = vnapply(dat, "[[", "error"), n_steps = viapply(dat, "[[", "n_steps"), n_steps_accepted = viapply(dat, "[[", "n_steps_accepted"), - n_steps_rejected = viapply(dat, "[[", "n_steps_rejected")) + n_steps_rejected = viapply(dat, "[[", "n_steps_rejected"), + events = I(lapply(dat, function(x) { + if (is.null(x$events)) NULL else as.data.frame(x$events) + }))) if (include_coefficients) { ret$coefficients <- I(lapply(dat, "[[", "coefficients")) } diff --git a/inst/include/dust2/continuous/events.hpp b/inst/include/dust2/continuous/events.hpp index 8bafc29e..c7e7c315 100644 --- a/inst/include/dust2/continuous/events.hpp +++ b/inst/include/dust2/continuous/events.hpp @@ -28,7 +28,7 @@ bool is_root(const real_type a, const real_type b, const root_type& root) { return false; } -template +template struct event { size_t index; real_type value; @@ -36,8 +36,18 @@ struct event { root_type root = root_type::both; }; -template +template using events_type = std::vector>; +template +struct event_history_element { + real_type time; + size_t index; + real_type sign; +}; + +template +using event_history = std::vector>; + } } diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index 68f5e912..c77187b1 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -34,6 +34,7 @@ template struct internals { history_step last; history history_values; + event_history events; std::vector dydt; std::vector step_times; @@ -334,6 +335,7 @@ class solver { bool found = false; size_t idx_first = 0; real_type t1 = t0 + h; + real_type sign = 0; for (size_t idx_event = 0; idx_event < events.size(); ++idx_event) { const auto& e = events[idx_event]; @@ -342,7 +344,9 @@ class solver { auto fn = [&](auto t) { return internals.last.interpolate(t, idx_state) - value; }; - if (is_root(fn(t0), fn(t1), e.root)) { + const auto f_t0 = fn(t0); + const auto f_t1 = fn(t1); + if (is_root(f_t0, f_t1, e.root)) { // These probably should move into the ode control, but there // should really be any great need to change them, and the // interpolation is expected to be quite fast and accurate. @@ -352,6 +356,7 @@ class solver { found = true; idx_first = idx_event; t1 = root.x; + sign = f_t0 < 0 ? 1 : -1; } if (found) { internals.last.interpolate(t1, y_next_.data()); @@ -361,7 +366,7 @@ class solver { // We need to modify the history here so that search will find // the right point. internals.last.t1 = t1; - // TODO: log event! + internals.events.push_back({t1, idx_first, sign}); } } return t1; diff --git a/inst/include/dust2/r/continuous/system.hpp b/inst/include/dust2/r/continuous/system.hpp index 17680651..01bbd6a1 100644 --- a/inst/include/dust2/r/continuous/system.hpp +++ b/inst/include/dust2/r/continuous/system.hpp @@ -67,7 +67,9 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals& internals, "n_steps_accepted"_nm = cpp11::as_sexp(internals.n_steps_accepted), "n_steps_rejected"_nm = cpp11::as_sexp(internals.n_steps_rejected), "coefficients"_nm = R_NilValue, - "history"_nm = R_NilValue}; + "history"_nm = R_NilValue, + "events"_nm = R_NilValue}; + if (include_coefficients) { auto r_coef = cpp11::writable::doubles_matrix<>(internals.last.c1.size(), 5); auto coef = REAL(r_coef); @@ -107,6 +109,21 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals& internals, "coefficients"_nm = std::move(r_history_coef)}; ret["history"] = cpp11::as_sexp(r_history); } + if (!internals.events.empty()) { + const auto n_events = internals.events.size(); + auto r_event_time = cpp11::writable::doubles(n_events); + auto r_event_index = cpp11::writable::integers(n_events); + auto r_event_sign = cpp11::writable::doubles(n_events); + for (size_t i = 0; i < n_events; ++i) { + r_event_time[i] = internals.events[i].time; + r_event_index[i] = static_cast(internals.events[i].index); + r_event_sign[i] = internals.events[i].sign; + } + auto r_events = cpp11::writable::list{"time"_nm = std::move(r_event_time), + "index"_nm = std::move(r_event_index), + "sign"_nm = std::move(r_event_sign)}; + ret["events"] = cpp11::as_sexp(r_events); + } return ret; } diff --git a/tests/testthat/test-zzz-events.R b/tests/testthat/test-zzz-events.R index 75f5a59e..414fc5ad 100644 --- a/tests/testthat/test-zzz-events.R +++ b/tests/testthat/test-zzz-events.R @@ -6,6 +6,9 @@ test_that("can run system with roots and events", { t <- seq(0, 6, length.out = 500) y <- dust_system_simulate(sys, t) + info <- dust_system_internals(sys) + expect_equal(nrow(info$events[[1]]), 3) + ## This is realy not great, but at this point I don't know who is ## worse. Qualitatively we're about right though and I need to go ## through and compare with an analytic solution as here we've got From 498e8345b959ff630b87a3c8dc043c9d656e1781 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 25 Nov 2024 10:45:12 +0000 Subject: [PATCH 06/19] Start generalising events --- inst/include/dust2/continuous/events.hpp | 16 +++++++++++++--- inst/include/dust2/continuous/solver.hpp | 13 ++++++------- tests/testthat/examples/event.cpp | 5 +++-- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/inst/include/dust2/continuous/events.hpp b/inst/include/dust2/continuous/events.hpp index c7e7c315..0ab4c954 100644 --- a/inst/include/dust2/continuous/events.hpp +++ b/inst/include/dust2/continuous/events.hpp @@ -30,10 +30,20 @@ bool is_root(const real_type a, const real_type b, const root_type& root) { template struct event { - size_t index; - real_type value; - std::function action; // time, y + using test_type = std::function; + using action_type = std::function; + std::vector index; root_type root = root_type::both; + test_type test; + action_type action; + + event(const std::vector& index, test_type test, action_type action, root_type root = root_type::both) : + index(index), root(root), test(test), action(action) { + } + + event(size_t index, action_type action, root_type root = root_type::both) : + event({index}, [](real_type t, const real_type* y) { return y[0]; }, action, root) { + } }; template diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index c77187b1..7c1898ef 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -332,17 +332,17 @@ class solver { real_type apply_events(real_type t0, real_type h, const real_type* y, const events_type& events, ode::internals& internals) { - bool found = false; - size_t idx_first = 0; + size_t idx_first = events.size(); real_type t1 = t0 + h; real_type sign = 0; for (size_t idx_event = 0; idx_event < events.size(); ++idx_event) { const auto& e = events[idx_event]; - const auto idx_state = e.index; - const auto value = e.value; + // temporary, will update soon - we can use y_stiff I think? + std::vector y_tmp(e.index.size()); auto fn = [&](auto t) { - return internals.last.interpolate(t, idx_state) - value; + internals.last.interpolate(t, e.index, y_tmp.begin()); + return e.test(t, y_tmp.data()); }; const auto f_t0 = fn(t0); const auto f_t1 = fn(t1); @@ -353,12 +353,11 @@ class solver { constexpr real_type eps = 1e-6; constexpr size_t steps = 100; auto root = lostturnip::find_result(fn, t0, t1, eps, steps); - found = true; idx_first = idx_event; t1 = root.x; sign = f_t0 < 0 ? 1 : -1; } - if (found) { + if (idx_first < events.size()) { internals.last.interpolate(t1, y_next_.data()); // These actions probably will have needed to bind // shared/internal eventually, that will be done elsewhere. diff --git a/tests/testthat/examples/event.cpp b/tests/testthat/examples/event.cpp index c580b179..88a185e4 100644 --- a/tests/testthat/examples/event.cpp +++ b/tests/testthat/examples/event.cpp @@ -59,10 +59,11 @@ class bounce { // We can capture 'shared' by scope here, but not internal; that // would require a second phase of binding, which would need to be // done by the system. - dust2::ode::event e{0, 0, [&](double t, double* y) { + auto action = [&](double t, double* y) { y[0] = 0; y[1] = -shared.damp * y[1]; - }}; + }; + dust2::ode::event e(0, action); return dust2::ode::events_type({e}); } }; From 266a30c3f5077a7dc190bbf1a87cdb384dfc170f Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 09:38:53 +0000 Subject: [PATCH 07/19] Allow access to shared/internal --- inst/include/dust2/continuous/system.hpp | 2 +- inst/include/dust2/properties.hpp | 23 +++++++++++++++++------ tests/testthat/examples/event.cpp | 2 +- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/inst/include/dust2/continuous/system.hpp b/inst/include/dust2/continuous/system.hpp index 4ead6719..59e4ec13 100644 --- a/inst/include/dust2/continuous/system.hpp +++ b/inst/include/dust2/continuous/system.hpp @@ -68,7 +68,7 @@ class dust_continuous { errors_(n_particles_total_), rng_(n_particles_total_, seed, deterministic), delays_(do_delays(shared_)), - events_(do_events(shared_)), + events_(do_events(shared_, internal_)), solver_(n_groups_ * n_threads_, {n_state_ode_, control_}), output_is_current_(n_groups_), requires_initialise_(n_groups_, true) { diff --git a/inst/include/dust2/properties.hpp b/inst/include/dust2/properties.hpp index e356718d..c4dbe6f9 100644 --- a/inst/include/dust2/properties.hpp +++ b/inst/include/dust2/properties.hpp @@ -169,20 +169,31 @@ auto do_delays(const std::vector& shared) { } template ::has_events::value, T>::type* = nullptr> -auto do_events(const std::vector& shared) { +auto do_events(const std::vector& shared, + std::vector& internal) { using real_type = typename T::real_type; std::vector> ret; - ret.reserve(shared.size()); - for (size_t i = 0; i < shared.size(); ++i) { - ret.push_back(T::events(shared[i])); + + const auto n_groups = shared.size(); + const auto n_threads = internal.size(); + ret.reserve(n_threads * n_groups); + auto iter = internal.begin(); + for (size_t i = 0; i < n_threads; ++i) { + for (auto& s : shared) { + ret.push_back(T::events(s, *iter)); + ++iter; + } } return ret; } template ::has_events::value, T>::type* = nullptr> -auto do_events(const std::vector& shared) { +auto do_events(const std::vector& shared, + std::vector& internal) { using real_type = typename T::real_type; - return std::vector>(shared.size(), dust2::ode::events_type{{}}); + const auto len = shared.size() * internal.size(); + const auto empty = dust2::ode::events_type{{}}; + return std::vector>(len, empty); } } diff --git a/tests/testthat/examples/event.cpp b/tests/testthat/examples/event.cpp index 88a185e4..8f0fd76c 100644 --- a/tests/testthat/examples/event.cpp +++ b/tests/testthat/examples/event.cpp @@ -55,7 +55,7 @@ class bounce { shared.damp = dust2::r::read_real(pars, "damp", shared.damp); } - static auto events(const shared_state& shared) { + static auto events(const shared_state& shared, internal_state& internal) { // We can capture 'shared' by scope here, but not internal; that // would require a second phase of binding, which would need to be // done by the system. From 2368c5f10cd6533f671eb3bf7d363f1074effe7d Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 10:48:23 +0000 Subject: [PATCH 08/19] Correct behaviour on truncation --- inst/include/dust2/continuous/solver.hpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index 7c1898ef..9597d1a4 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -156,6 +156,7 @@ class solver { auto success = false; auto reject = false; auto truncated = false; + auto event = false; auto h = internals.step_size; while (!success) { @@ -185,10 +186,12 @@ class solver { update_interpolation(t, h, y, internals); // If we end up using a std::array of these, we can make this // constexpr, which is nice. + // This would be step 21 atm. if (!events.empty()) { const auto t_next = apply_events(t, h, y, events, internals); if (t_next < t + h) { - truncated = true; + event = true; + truncated = false; h = t_next - t; rhs(t_next, y_next_.data(), k2_.data()); } @@ -199,7 +202,7 @@ class solver { internals.step_times.push_back(truncated ? t_end : t + h); } internals.save_history(); - if (!truncated) { + if (!truncated && !event) { const auto fac_old = std::max(internals.error, static_cast(1e-4)); auto fac = fac11 / std::pow(fac_old, control_.beta); From 5626e9acfd5212ddf09c0c669aad79f212d3b80b Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 10:58:00 +0000 Subject: [PATCH 09/19] Drop deSolve for an analytic solution --- DESCRIPTION | 1 - tests/testthat/helper-dust.R | 25 +++++++------------------ 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 6257b20b..b205b64a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -27,7 +27,6 @@ Suggests: callr, cpp11, decor, - deSolve, fs, glue, knitr, diff --git a/tests/testthat/helper-dust.R b/tests/testthat/helper-dust.R index 7c740306..af91195f 100644 --- a/tests/testthat/helper-dust.R +++ b/tests/testthat/helper-dust.R @@ -85,23 +85,12 @@ local_sir_generator <- function() { } -example_bounce <- function(t) { - skip_if_not_installed("deSolve") - ball <- function(t, y, parms) { - dy1 <- y[2] - dy2 <- -9.8 - list(c(dy1, dy2)) +example_bounce_analytic <- function(t, v0 = 10, damp = 0.9, g = 9.8) { + t0 <- 0 + while (last(t0) < last(t)) { + t0 <- c(t0, last(t0) + 2 * v0 * damp^(length(t0) - 1) / g) } - yini <- c(height = 0, velocity = 10) - rootfunc <- function(t, y, parms) { - y[1] - } - eventfunc <- function(t, y, parms) { - y[1] <- 0 - y[2] <- -0.9 * y[2] - y - } - deSolve::ode(times = t, y = yini, func = ball, - parms = NULL, rootfunc = rootfunc, - events = list(func = eventfunc, root = TRUE)) + i <- findInterval(t, t0) + y <- v0 * damp^(i - 1) * (t - t0[i]) - 0.5 * g * (t - t0[i])^2 + list(y = y, roots = t0[t0 > 0 & t0 < last(t)]) } From 4d3e3c8056aa849a8e666093cf3f5db23956e956 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 10:59:29 +0000 Subject: [PATCH 10/19] Fix test --- tests/testthat/test-zzz-events.R | 33 +++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/tests/testthat/test-zzz-events.R b/tests/testthat/test-zzz-events.R index 414fc5ad..ae782a6b 100644 --- a/tests/testthat/test-zzz-events.R +++ b/tests/testthat/test-zzz-events.R @@ -1,19 +1,30 @@ test_that("can run system with roots and events", { gen <- dust_compile("examples/event.cpp", quiet = TRUE, debug = TRUE) - sys <- dust_system_create(gen) + control <- dust_ode_control(debug_record_step_times = TRUE, save_history = TRUE) + sys <- dust_system_create(gen, ode_control = control) dust_system_set_state_initial(sys) - t <- seq(0, 6, length.out = 500) + + ## Use relatively few points for output here as this exacerbates + ## problems, even though the solution looks silly. + t <- seq(0, 6, length.out = 60) y <- dust_system_simulate(sys, t) + cmp <- example_bounce_analytic(t) + + info <- dust_system_internals(sys, include_history = TRUE) + + ## Find all roots: + r <- info$events[[1]] + expect_equal(nrow(r, 3)) + expect_equal(r$time, cmp$roots, tolerance = 1e-6) + expect_equal(r$index, rep(0L, 3)) + expect_equal(r$sign, rep(-1, 3)) - info <- dust_system_internals(sys) - expect_equal(nrow(info$events[[1]]), 3) + ## Stop at all roots: + h <- info$history[[1]] + expect_true(all(r$time %in% h$t0)) + expect_true(all(r$time %in% h$t1)) - ## This is realy not great, but at this point I don't know who is - ## worse. Qualitatively we're about right though and I need to go - ## through and compare with an analytic solution as here we've got - ## two different approximations and a nonlinear system that is - ## acumulating error. - cmp <- example_bounce(t) - expect_equal(y[1, ], cmp[, 2], tolerance = 1e-2) + ## Overall solution: + expect_equal(y[1, ], cmp$y, tolerance = 1e-6) }) From 2c975c93409e909b7b8acdcee0e00464f771f9bf Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 11:06:19 +0000 Subject: [PATCH 11/19] Mark lostturnip as vendored --- .covrignore | 1 + .gitattributes | 1 + 2 files changed, 2 insertions(+) diff --git a/.covrignore b/.covrignore index 635cf5df..7e497293 100644 --- a/.covrignore +++ b/.covrignore @@ -7,3 +7,4 @@ src/malaria.cpp src/sir.cpp src/sirode.cpp src/walk.cpp +inst/include/lostturnip.hpp diff --git a/.gitattributes b/.gitattributes index 3614828a..bc179e68 100644 --- a/.gitattributes +++ b/.gitattributes @@ -7,3 +7,4 @@ src/sir.cpp linguist-generated=true src/sirode.cpp linguist-generated=true src/walk.cpp linguist-generated=true R/import-*.R linguist-generated=true +inst/include/lostturnip.hpp linguist-vendored=true From 2dca76affe209f0e8b738dddd2fd2d58ad36ea04 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 11:06:32 +0000 Subject: [PATCH 12/19] Cleanup --- inst/include/dust2/continuous/solver.hpp | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index 9597d1a4..c60faaf6 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -184,9 +184,6 @@ class solver { if (err <= 1) { success = true; update_interpolation(t, h, y, internals); - // If we end up using a std::array of these, we can make this - // constexpr, which is nice. - // This would be step 21 atm. if (!events.empty()) { const auto t_next = apply_events(t, h, y, events, internals); if (t_next < t + h) { @@ -341,11 +338,12 @@ class solver { for (size_t idx_event = 0; idx_event < events.size(); ++idx_event) { const auto& e = events[idx_event]; - // temporary, will update soon - we can use y_stiff I think? - std::vector y_tmp(e.index.size()); + // Use y_stiff as temporary space here, it's only used + // transiently and within the step + real_type * y_t = y_stiff_.data(); auto fn = [&](auto t) { - internals.last.interpolate(t, e.index, y_tmp.begin()); - return e.test(t, y_tmp.data()); + internals.last.interpolate(t, e.index, y_t); + return e.test(t, y_t); }; const auto f_t0 = fn(t0); const auto f_t1 = fn(t1); @@ -362,8 +360,6 @@ class solver { } if (idx_first < events.size()) { internals.last.interpolate(t1, y_next_.data()); - // These actions probably will have needed to bind - // shared/internal eventually, that will be done elsewhere. events[idx_first].action(t1, y_next_.data()); // We need to modify the history here so that search will find // the right point. From 234d91a07db01efc69f840271d204ba9d142ea97 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 11:09:12 +0000 Subject: [PATCH 13/19] Fix test --- tests/testthat/test-zzz-events.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-zzz-events.R b/tests/testthat/test-zzz-events.R index ae782a6b..f03c6263 100644 --- a/tests/testthat/test-zzz-events.R +++ b/tests/testthat/test-zzz-events.R @@ -15,7 +15,7 @@ test_that("can run system with roots and events", { ## Find all roots: r <- info$events[[1]] - expect_equal(nrow(r, 3)) + expect_equal(nrow(r), 3) expect_equal(r$time, cmp$roots, tolerance = 1e-6) expect_equal(r$index, rep(0L, 3)) expect_equal(r$sign, rep(-1, 3)) From 4eeb4339c44b2ba68e4355cd254814187e2d6437 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 11:11:03 +0000 Subject: [PATCH 14/19] Pass sign into event function too --- inst/include/dust2/continuous/events.hpp | 2 +- inst/include/dust2/continuous/solver.hpp | 2 +- tests/testthat/examples/event.cpp | 5 +---- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/inst/include/dust2/continuous/events.hpp b/inst/include/dust2/continuous/events.hpp index 0ab4c954..3a13ed85 100644 --- a/inst/include/dust2/continuous/events.hpp +++ b/inst/include/dust2/continuous/events.hpp @@ -31,7 +31,7 @@ bool is_root(const real_type a, const real_type b, const root_type& root) { template struct event { using test_type = std::function; - using action_type = std::function; + using action_type = std::function; std::vector index; root_type root = root_type::both; test_type test; diff --git a/inst/include/dust2/continuous/solver.hpp b/inst/include/dust2/continuous/solver.hpp index c60faaf6..07e9d460 100644 --- a/inst/include/dust2/continuous/solver.hpp +++ b/inst/include/dust2/continuous/solver.hpp @@ -360,7 +360,7 @@ class solver { } if (idx_first < events.size()) { internals.last.interpolate(t1, y_next_.data()); - events[idx_first].action(t1, y_next_.data()); + events[idx_first].action(t1, sign, y_next_.data()); // We need to modify the history here so that search will find // the right point. internals.last.t1 = t1; diff --git a/tests/testthat/examples/event.cpp b/tests/testthat/examples/event.cpp index 8f0fd76c..378524c2 100644 --- a/tests/testthat/examples/event.cpp +++ b/tests/testthat/examples/event.cpp @@ -56,10 +56,7 @@ class bounce { } static auto events(const shared_state& shared, internal_state& internal) { - // We can capture 'shared' by scope here, but not internal; that - // would require a second phase of binding, which would need to be - // done by the system. - auto action = [&](double t, double* y) { + auto action = [&](const double t, const double sign, double* y) { y[0] = 0; y[1] = -shared.damp * y[1]; }; From 4abb7db46b9beb62d9f10f8dc196f244a5d718e5 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 11:12:56 +0000 Subject: [PATCH 15/19] Exclude from diff --- .gitattributes | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitattributes b/.gitattributes index bc179e68..886f2359 100644 --- a/.gitattributes +++ b/.gitattributes @@ -7,4 +7,4 @@ src/sir.cpp linguist-generated=true src/sirode.cpp linguist-generated=true src/walk.cpp linguist-generated=true R/import-*.R linguist-generated=true -inst/include/lostturnip.hpp linguist-vendored=true +inst/include/lostturnip.hpp linguist-vendored=true linguist-generated=true From 64fb8c7880801940de42a119e2ba6a0c8518f6f6 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 11:16:52 +0000 Subject: [PATCH 16/19] Tidy up diff --- inst/include/dust2/continuous/history.hpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/inst/include/dust2/continuous/history.hpp b/inst/include/dust2/continuous/history.hpp index 5402754e..de21f4e1 100644 --- a/inst/include/dust2/continuous/history.hpp +++ b/inst/include/dust2/continuous/history.hpp @@ -66,15 +66,6 @@ struct history_step { } } - real_type interpolate(real_type time, size_t i) const { - // Consider special case for u or v == 0 - // u == 0: return c1[i] - // v == 0: return c1[i] + c2[i] - const auto u = (time - t0) / h; - const auto v = 1 - u; - return c1[i] + u * (c2[i] + v * (c3[i] + u * (c4[i] + v * c5[i]))); - } - history_step subset(std::vector index) const { return history_step(t0, t1, From 92f0e5894db95f32c2d01b85bd37663dcdc10776 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 11:17:02 +0000 Subject: [PATCH 17/19] Fix index --- inst/include/dust2/continuous/system.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/inst/include/dust2/continuous/system.hpp b/inst/include/dust2/continuous/system.hpp index 59e4ec13..2120a01c 100644 --- a/inst/include/dust2/continuous/system.hpp +++ b/inst/include/dust2/continuous/system.hpp @@ -91,7 +91,7 @@ class dust_continuous { const auto offset = k * n_state_; real_type * y = state_data + offset; try { - solver_[i].run(time_, time, y, zero_every_[group], events_[group], + solver_[i].run(time_, time, y, zero_every_[group], events_[i], ode_internals_[k], rhs_(particle, group, thread)); } catch (std::exception const& e) { @@ -132,7 +132,7 @@ class dust_continuous { for (size_t step = 0; step < n_steps; ++step) { const real_type t0 = t1; t1 = (step == n_steps - 1) ? time : time_ + step * dt_; - solver_[i].run(t0, t1, y, zero_every_[group], events_[group], + solver_[i].run(t0, t1, y, zero_every_[group], events_[i], ode_internals_[k], rhs_(particle, group, thread)); std::copy_n(y, n_state_ode_, y_other); From 494b5d29d8b8e0fb2f259f996aff00bb2aec3bd8 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 26 Nov 2024 16:14:22 +0000 Subject: [PATCH 18/19] Bump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index b205b64a..5976dfcb 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: dust2 Title: Next Generation dust -Version: 0.3.10 +Version: 0.3.11 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Imperial College of Science, Technology and Medicine", From da2af0c73d98bef9c3b691a4057057253d748a6a Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 28 Nov 2024 12:44:48 +0000 Subject: [PATCH 19/19] Apply suggestions from code review --- inst/include/dust2/r/continuous/system.hpp | 2 +- tests/testthat/test-zzz-events.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/inst/include/dust2/r/continuous/system.hpp b/inst/include/dust2/r/continuous/system.hpp index 01bbd6a1..9ff53981 100644 --- a/inst/include/dust2/r/continuous/system.hpp +++ b/inst/include/dust2/r/continuous/system.hpp @@ -116,7 +116,7 @@ cpp11::sexp ode_internals_to_sexp(const ode::internals& internals, auto r_event_sign = cpp11::writable::doubles(n_events); for (size_t i = 0; i < n_events; ++i) { r_event_time[i] = internals.events[i].time; - r_event_index[i] = static_cast(internals.events[i].index); + r_event_index[i] = static_cast(internals.events[i].index) + 1; r_event_sign[i] = internals.events[i].sign; } auto r_events = cpp11::writable::list{"time"_nm = std::move(r_event_time), diff --git a/tests/testthat/test-zzz-events.R b/tests/testthat/test-zzz-events.R index f03c6263..393f3dd3 100644 --- a/tests/testthat/test-zzz-events.R +++ b/tests/testthat/test-zzz-events.R @@ -17,7 +17,7 @@ test_that("can run system with roots and events", { r <- info$events[[1]] expect_equal(nrow(r), 3) expect_equal(r$time, cmp$roots, tolerance = 1e-6) - expect_equal(r$index, rep(0L, 3)) + expect_equal(r$index, rep(1L, 3)) expect_equal(r$sign, rep(-1, 3)) ## Stop at all roots: