From 047550df37f05ba711bd1da7f008fe710b54f9f2 Mon Sep 17 00:00:00 2001
From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com>
Date: Tue, 17 Sep 2024 11:28:38 +0200
Subject: [PATCH] Create less intermediate data from events (#2249)

This PR optimizes event handling by removing the intermediate structure created by this chain of transformations:

1. `event_lanes := vector<vector<spike_event>>` (one lane per local cell, each lane sorted by time)
2. `staged_events_per_mech_id := vector<vector<vector<deliverable_event>>>` (one vector per mech id and time step, sorted by time)
3. `vector<event_stream>`, one stream per mech id

The following optimisations were performed:

- cut out the middle step (2) completely, as it is wholly unneeded, and sort directly into event streams
- remove a spurious index structure from `cable_cell_group`
- slim down `deliverable_event` and `deliverable_event_data`
- `event_stream` now uses a partition instead of a vector of ranges for splitting its data into `dt` buckets. (Save 8B per `dt` ;))

A standalone sketch of the new staging scheme is given after the diff below.

The result is that the quite pathological example `calcium_stdp.py` (as given in), which generates immense amounts of spikes using a single cell group and a single epoch, drops from 3.8GB to 1.9GB peak heap usage, at the same or slightly lower runtime (<5% difference).

## TODO

- [x] Port to GPU.
- [X] Tests pass
- [x] Examples run through
- [x] Fix tests... Interestingly, _locally_ all tests pass on my dev machine regardless of optimisation, vectorisation, and assertion settings.

---------

Co-authored-by: boeschf <48126478+boeschf@users.noreply.github.com>
---
 arbor/backends/event.hpp                      |   8 +-
 arbor/backends/event_stream_base.hpp          |  88 ++++++++---
 arbor/backends/gpu/event_stream.hpp           | 142 ++++++++----------
 arbor/backends/gpu/shared_state.cpp           |  13 +-
 arbor/backends/gpu/shared_state.cu            |  20 +--
 arbor/backends/gpu/shared_state.hpp           |  13 +-
 arbor/backends/multicore/event_stream.hpp     |  43 +++---
 arbor/backends/multicore/fvm.hpp              |   5 -
 arbor/backends/multicore/multicore_common.hpp |   1 -
 arbor/backends/multicore/shared_state.cpp     |  18 ++-
 arbor/backends/multicore/shared_state.hpp     |  12 +-
 arbor/backends/shared_state_base.hpp          |  30 ++--
 arbor/cable_cell_group.cpp                    |  55 +------
 arbor/cable_cell_group.hpp                    |  28 +---
 arbor/cell_group.hpp                          |   6 +-
 arbor/event_lane.hpp                          |  11 ++
 arbor/fvm_lowered_cell.hpp                    |  18 +--
 arbor/fvm_lowered_cell_impl.hpp               |  43 +++---
 test/unit/test_fvm_layout.cpp                 |   2 +-
 test/unit/test_fvm_lowered.cpp                |  30 ++--
 test/unit/test_probe.cpp                      |  29 ++--
 test/unit/test_synapses.cpp                   |  13 +-
 22 files changed, 301 insertions(+), 327 deletions(-)
 create mode 100644 arbor/event_lane.hpp

diff --git a/arbor/backends/event.hpp b/arbor/backends/event.hpp
index c379d780f7..81df1b38ed 100644
--- a/arbor/backends/event.hpp
+++ b/arbor/backends/event.hpp
@@ -51,10 +51,14 @@ struct has_event_index : public std::true_type {};
 // Subset of event information required for mechanism delivery.
struct deliverable_event_data { - cell_local_size_type mech_id; // same as target_handle::mech_id cell_local_size_type mech_index; // same as target_handle::mech_index float weight; - ARB_SERDES_ENABLE(deliverable_event_data, mech_id, mech_index, weight); + deliverable_event_data(cell_local_size_type idx, float w): + mech_index(idx), + weight(w) {} + ARB_SERDES_ENABLE(deliverable_event_data, + mech_index, + weight); }; // Stream index accessor function for multi_event_stream: diff --git a/arbor/backends/event_stream_base.hpp b/arbor/backends/event_stream_base.hpp index 1f7d67005d..86610f1608 100644 --- a/arbor/backends/event_stream_base.hpp +++ b/arbor/backends/event_stream_base.hpp @@ -1,63 +1,107 @@ #pragma once -#include #include #include +#include + #include "backends/event.hpp" #include "backends/event_stream_state.hpp" +#include "event_lane.hpp" +#include "timestep_range.hpp" +#include "util/partition.hpp" + +ARB_SERDES_ENABLE_EXT(arb_deliverable_event_data, mech_index, weight); namespace arb { -template -class event_stream_base { -public: // member types +template +struct event_stream_base { using size_type = std::size_t; using event_type = Event; using event_time_type = ::arb::event_time_type; using event_data_type = ::arb::event_data_type; -protected: // private member types - using span_type = Span; - - static_assert(std::is_same().begin()), event_data_type*>::value); - static_assert(std::is_same().end()), event_data_type*>::value); - protected: // members std::vector ev_data_; - std::vector ev_spans_; + std::vector ev_spans_ = {0}; size_type index_ = 0; + event_data_type* base_ptr_ = nullptr; public: event_stream_base() = default; // returns true if the currently marked time step has no events bool empty() const { - return ev_spans_.empty() || ev_data_.empty() || !index_ || index_ > ev_spans_.size() || - !ev_spans_[index_-1].size(); + return ev_data_.empty() // No events + || index_ < 1 // Since we index with a left bias, index_ must be at least 1 + || index_ >= ev_spans_.size() // Cannot index at container length + || ev_spans_[index_-1] >= ev_spans_[index_]; // Current span is empty } - void mark() { - index_ += (index_ <= ev_spans_.size() ? 1 : 0); - } + void mark() { index_ += 1; } auto marked_events() { - using std::begin; - using std::end; - if (empty()) { - return make_event_stream_state((event_data_type*)nullptr, (event_data_type*)nullptr); - } else { - return make_event_stream_state(begin(ev_spans_[index_-1]), end(ev_spans_[index_-1])); + auto beg = (event_data_type*)nullptr; + auto end = (event_data_type*)nullptr; + if (!empty()) { + beg = base_ptr_ + ev_spans_[index_-1]; + end = base_ptr_ + ev_spans_[index_]; } + return make_event_stream_state(beg, end); } // clear all previous data void clear() { ev_data_.clear(); + // Clear + push doesn't allocate a new vector ev_spans_.clear(); + ev_spans_.push_back(0); + base_ptr_ = nullptr; index_ = 0; } + + // Construct a mapping of mech_id to a stream s.t. 
streams are partitioned into + // time step buckets by `ev_span` + template + static std::enable_if_t> + multi_event_stream(const event_lane_subrange& lanes, + const std::vector& handles, + const std::vector& divs, + const timestep_range& steps, + std::unordered_map& streams) { + auto n_steps = steps.size(); + + std::unordered_map> dt_sizes; + for (auto& [k, v]: streams) { + v.clear(); + dt_sizes[k].resize(n_steps, 0); + } + + auto cell = 0; + for (auto& lane: lanes) { + auto div = divs[cell]; + arb_size_type step = 0; + for (auto evt: lane) { + auto time = evt.time; + auto weight = evt.weight; + auto target = evt.target; + while(step < n_steps && time >= steps[step].t_end()) ++step; + // Events coinciding with epoch's upper boundary belong to next epoch + if (step >= n_steps) break; + auto& handle = handles[div + target]; + streams[handle.mech_id].ev_data_.push_back({handle.mech_index, weight}); + dt_sizes[handle.mech_id][step]++; + } + ++cell; + } + + for (auto& [id, stream]: streams) { + util::make_partition(stream.ev_spans_, dt_sizes[id]); + stream.init(); + } + } }; } // namespace arb diff --git a/arbor/backends/gpu/event_stream.hpp b/arbor/backends/gpu/event_stream.hpp index 5377d26893..0045d93813 100644 --- a/arbor/backends/gpu/event_stream.hpp +++ b/arbor/backends/gpu/event_stream.hpp @@ -2,44 +2,33 @@ // Indexed collection of pop-only event queues --- CUDA back-end implementation. +#include + #include "backends/event_stream_base.hpp" -#include "memory/memory.hpp" -#include "util/partition.hpp" -#include "util/range.hpp" -#include "util/rangeutil.hpp" #include "util/transform.hpp" #include "threading/threading.hpp" -#include - -ARB_SERDES_ENABLE_EXT(arb_deliverable_event_data, mech_index, weight); +#include "timestep_range.hpp" +#include "memory/memory.hpp" namespace arb { namespace gpu { template -class event_stream : - public event_stream_base>::view_type> { +struct event_stream: public event_stream_base { public: - using base = event_stream_base>::view_type>; + using base = event_stream_base; using size_type = typename base::size_type; using event_data_type = typename base::event_data_type; using device_array = memory::device_vector; -private: // members - task_system_handle thread_pool_; - device_array device_ev_data_; - std::vector offsets_; + using base::clear; + using base::ev_data_; + using base::ev_spans_; + using base::base_ptr_; -public: event_stream() = default; event_stream(task_system_handle t): base(), thread_pool_{t} {} - void clear() { - base::clear(); - offsets_.clear(); - } - // Initialize event streams from a vector of vector of events // Outer vector represents time step bins void init(const std::vector>& staged) { @@ -54,30 +43,31 @@ class event_stream : if (!num_events) return; // allocate space for spans and data - base::ev_spans_.resize(staged.size()); - base::ev_data_.resize(num_events); - offsets_.resize(staged.size()+1); + ev_spans_.resize(staged.size() + 1); + ev_data_.resize(num_events); resize(device_ev_data_, num_events); // compute offsets by exclusive scan over staged events - util::make_partition(offsets_, - util::transform_view(staged, [&](const auto& v) { return v.size(); }), - (size_type)0u); + util::make_partition(ev_spans_, + util::transform_view(staged, [](const auto& v) { return v.size(); }), + 0ull); // assign, copy to device (and potentially sort) the event data in parallel arb_assert(thread_pool_); - threading::parallel_for::apply(0, staged.size(), thread_pool_.get(), - [this,&staged](size_type i) { - const auto offset = offsets_[i]; - 
const auto size = staged[i].size(); - // add device range - base::ev_spans_[i] = device_ev_data_(offset, offset + size); - // host span - auto host_span = memory::make_view(base::ev_data_)(offset, offset + size); + arb_assert(ev_spans_.size() == staged.size() + 1); + threading::parallel_for::apply(0, ev_spans_.size() - 1, thread_pool_.get(), + [this, &staged](size_type i) { + const auto beg = ev_spans_[i]; + const auto end = ev_spans_[i + 1]; + arb_assert(end >= beg); + const auto len = end - beg; + + auto host_span = memory::make_view(ev_data_)(beg, end); + // make event data and copy std::copy_n(util::transform_view(staged[i], [](const auto& x) { return event_data(x); }).begin(), - size, + len, host_span.begin()); // sort if necessary if constexpr (has_event_index::value) { @@ -85,56 +75,41 @@ class event_stream : [](const event_data_type& ed) { return event_index(ed); }); } // copy to device - memory::copy_async(host_span, base::ev_spans_[i]); + auto device_span = memory::make_view(device_ev_data_)(beg, end); + memory::copy_async(host_span, device_span); }); - arb_assert(num_events == base::ev_data_.size()); - } + base_ptr_ = device_ev_data_.data(); - friend void serialize(serializer& ser, const std::string& k, const event_stream& t) { - ser.begin_write_map(::arb::to_serdes_key(k)); - ARB_SERDES_WRITE(ev_data_); - ser.begin_write_map("ev_spans_"); - auto base_ptr = t.device_ev_data_.data(); - for (size_t ix = 0; ix < t.ev_spans_.size(); ++ix) { - ser.begin_write_map(std::to_string(ix)); - const auto& span = t.ev_spans_[ix]; - ser.write("offset", static_cast(span.begin() - base_ptr)); - ser.write("size", static_cast(span.size())); - ser.end_write_map(); - } - ser.end_write_map(); - ARB_SERDES_WRITE(index_); - ARB_SERDES_WRITE(device_ev_data_); - ARB_SERDES_WRITE(offsets_); - ser.end_write_map(); + arb_assert(num_events == device_ev_data_.size()); + arb_assert(num_events == ev_data_.size()); } - friend void deserialize(serializer& ser, const std::string& k, event_stream& t) { - ser.begin_read_map(::arb::to_serdes_key(k)); - ARB_SERDES_READ(ev_data_); - ser.begin_read_map("ev_spans_"); - for (size_t ix = 0; ser.next_key(); ++ix) { - ser.begin_read_map(std::to_string(ix)); - unsigned long long offset = 0, size = 0; - ser.read("offset", offset); - ser.read("size", size); - typename base::span_type span{t.ev_data_.data() + offset, size}; - if (ix < t.ev_spans_.size()) { - t.ev_spans_[ix] = span; - } else { - t.ev_spans_.emplace_back(span); - } - ser.end_read_map(); - } - ser.end_read_map(); - ARB_SERDES_READ(index_); - ARB_SERDES_READ(device_ev_data_); - ARB_SERDES_READ(offsets_); - ser.end_read_map(); + // Initialize event stream assuming ev_data_ and ev_span_ has + // been set previously (e.g. 
by `base::multi_event_stream`) + void init() { + resize(device_ev_data_, ev_data_.size()); + base_ptr_ = device_ev_data_.data(); + + threading::parallel_for::apply(0, ev_spans_.size() - 1, thread_pool_.get(), + [this](size_type i) { + const auto beg = ev_spans_[i]; + const auto end = ev_spans_[i + 1]; + arb_assert(end >= beg); + + auto host_span = memory::make_view(ev_data_)(beg, end); + auto device_span = memory::make_view(device_ev_data_)(beg, end); + + // sort if necessary + if constexpr (has_event_index::value) { + util::stable_sort_by(host_span, + [](const event_data_type& ed) { return event_index(ed); }); + } + // copy to device + memory::copy_async(host_span, device_span); + }); } -private: template static void resize(D& d, std::size_t size) { // resize if necessary @@ -142,6 +117,15 @@ class event_stream : d = D(size); } } + + ARB_SERDES_ENABLE(event_stream, + ev_data_, + ev_spans_, + device_ev_data_, + index_); + + task_system_handle thread_pool_; + device_array device_ev_data_; }; } // namespace gpu diff --git a/arbor/backends/gpu/shared_state.cpp b/arbor/backends/gpu/shared_state.cpp index a15a3f9545..928c938ba1 100644 --- a/arbor/backends/gpu/shared_state.cpp +++ b/arbor/backends/gpu/shared_state.cpp @@ -13,12 +13,10 @@ #include "backends/event_stream_state.hpp" #include "backends/gpu/chunk_writer.hpp" #include "memory/copy.hpp" -#include "memory/gpu_wrappers.hpp" #include "memory/wrappers.hpp" #include "util/index_into.hpp" #include "util/rangeutil.hpp" #include "util/maputil.hpp" -#include "util/meta.hpp" #include "util/range.hpp" #include "util/strprintf.hpp" @@ -241,7 +239,8 @@ void shared_state::instantiate(mechanism& m, m.ppack_.n_detectors = n_detector; if (storage.count(id)) throw arb::arbor_internal_error("Duplicate mech id in shared state"); - auto& store = storage.emplace(id, mech_storage{thread_pool}).first->second; + auto& store = storage.emplace(id, mech_storage{}).first->second; + streams[id] = deliverable_event_stream{thread_pool}; // Allocate view pointers store.state_vars_ = std::vector(m.mech_.n_state_vars); @@ -389,6 +388,14 @@ void shared_state::take_samples() { } } +void shared_state::init_events(const event_lane_subrange& lanes, + const std::vector& handles, + const std::vector& divs, + const timestep_range& dts) { + arb::gpu::event_stream::multi_event_stream(lanes, handles, divs, dts, streams); +} + + // Debug interface ARB_ARBOR_API std::ostream& operator<<(std::ostream& o, shared_state& s) { using io::csv; diff --git a/arbor/backends/gpu/shared_state.cu b/arbor/backends/gpu/shared_state.cu index 403067d70f..db22c03b0e 100644 --- a/arbor/backends/gpu/shared_state.cu +++ b/arbor/backends/gpu/shared_state.cu @@ -1,7 +1,5 @@ // GPU kernels and wrappers for shared state methods. 
-#include - #include #include @@ -24,13 +22,11 @@ __global__ void add_scalar(unsigned n, } } -__global__ void take_samples_impl( - const raw_probe_info* __restrict__ const begin_marked, - const raw_probe_info* __restrict__ const end_marked, - const arb_value_type time, - arb_value_type* __restrict__ const sample_time, - arb_value_type* __restrict__ const sample_value) -{ +__global__ void take_samples_impl(const raw_probe_info* __restrict__ const begin_marked, + const raw_probe_info* __restrict__ const end_marked, + const arb_value_type time, + arb_value_type* __restrict__ const sample_time, + arb_value_type* __restrict__ const sample_value) { const unsigned i = threadIdx.x+blockIdx.x*blockDim.x; const unsigned nsamples = end_marked - begin_marked; if (i, n, data, v); } -void take_samples_impl( - const event_stream_state& s, - const arb_value_type& time, arb_value_type* sample_time, arb_value_type* sample_value) -{ +void take_samples_impl(const event_stream_state& s, + const arb_value_type& time, arb_value_type* sample_time, arb_value_type* sample_value) { launch_1d(s.size(), 128, kernel::take_samples_impl, s.begin_marked, s.end_marked, time, sample_time, sample_value); } diff --git a/arbor/backends/gpu/shared_state.hpp b/arbor/backends/gpu/shared_state.hpp index 31656c99be..030cbcdab5 100644 --- a/arbor/backends/gpu/shared_state.hpp +++ b/arbor/backends/gpu/shared_state.hpp @@ -115,8 +115,6 @@ struct ARB_ARBOR_API istim_state { }; struct mech_storage { - mech_storage() = default; - mech_storage(task_system_handle tp) : deliverable_events_(tp) {} array data_; iarray indices_; std::vector globals_; @@ -127,7 +125,6 @@ struct mech_storage { memory::device_vector state_vars_d_; memory::device_vector ion_states_d_; random_numbers random_numbers_; - deliverable_event_stream deliverable_events_; }; struct ARB_ARBOR_API shared_state: shared_state_base { @@ -172,6 +169,7 @@ struct ARB_ARBOR_API shared_state: shared_state_base ion_data; std::unordered_map storage; + std::unordered_map streams; shared_state() = default; @@ -242,6 +240,11 @@ struct ARB_ARBOR_API shared_state: shared_state_base& handles, + const std::vector& divs, + const timestep_range& dts); }; // For debugging only @@ -253,12 +256,12 @@ ARB_SERDES_ENABLE_EXT(gpu::ion_state, Xd_, gX_); ARB_SERDES_ENABLE_EXT(gpu::mech_storage, data_, // NOTE(serdes) ion_states_, this is just a bunch of pointers - random_numbers_, - deliverable_events_); + random_numbers_); ARB_SERDES_ENABLE_EXT(gpu::shared_state, cbprng_seed, ion_data, storage, + streams, voltage, current_density, conductivity, diff --git a/arbor/backends/multicore/event_stream.hpp b/arbor/backends/multicore/event_stream.hpp index b03e2e3181..f280777db9 100644 --- a/arbor/backends/multicore/event_stream.hpp +++ b/arbor/backends/multicore/event_stream.hpp @@ -2,24 +2,26 @@ // Indexed collection of pop-only event queues --- multicore back-end implementation. 
+#include "arbor/spike_event.hpp" #include "backends/event_stream_base.hpp" -#include "util/range.hpp" -#include "util/rangeutil.hpp" +#include "timestep_range.hpp" namespace arb { namespace multicore { template -class event_stream : public event_stream_base*>> { -public: - using base = event_stream_base*>>; +struct event_stream: public event_stream_base { + using base = event_stream_base; using size_type = typename base::size_type; - event_stream() = default; - using base::clear; + using base::ev_spans_; + using base::ev_data_; + using base::base_ptr_; + + event_stream() = default; - // Initialize event streams from a vector of vector of events + // Initialize event stream from a vector of vector of events // Outer vector represents time step bins void init(const std::vector>& staged) { // clear previous data @@ -33,23 +35,28 @@ class event_stream : public event_stream_base, ev_data_, ev_spans_, index_); -}; + // Initialize event stream assuming ev_data_ and ev_span_ has + // been set previously (e.g. by `base::multi_event_stream`) + void init() { base_ptr_ = ev_data_.data(); } + ARB_SERDES_ENABLE(event_stream, + ev_data_, + ev_spans_, + index_); +}; } // namespace multicore } // namespace arb diff --git a/arbor/backends/multicore/fvm.hpp b/arbor/backends/multicore/fvm.hpp index 1413a70b52..be787cbcc4 100644 --- a/arbor/backends/multicore/fvm.hpp +++ b/arbor/backends/multicore/fvm.hpp @@ -1,19 +1,14 @@ #pragma once #include -#include #include -#include "backends/event.hpp" -#include "backends/multicore/event_stream.hpp" #include "backends/multicore/multicore_common.hpp" #include "backends/multicore/shared_state.hpp" #include "backends/multicore/diffusion_solver.hpp" #include "backends/multicore/cable_solver.hpp" #include "backends/multicore/threshold_watcher.hpp" -#include "execution_context.hpp" -#include "util/padded_alloc.hpp" #include "util/range.hpp" #include "util/rangeutil.hpp" diff --git a/arbor/backends/multicore/multicore_common.hpp b/arbor/backends/multicore/multicore_common.hpp index 6c88701f37..9186628eb0 100644 --- a/arbor/backends/multicore/multicore_common.hpp +++ b/arbor/backends/multicore/multicore_common.hpp @@ -5,7 +5,6 @@ // // Defines array, iarray, and specialized multi-event stream classes. 
-#include #include #include diff --git a/arbor/backends/multicore/shared_state.cpp b/arbor/backends/multicore/shared_state.cpp index a32ea19b0f..17a57789fe 100644 --- a/arbor/backends/multicore/shared_state.cpp +++ b/arbor/backends/multicore/shared_state.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include @@ -15,8 +14,6 @@ #include #include -#include "backends/event.hpp" -#include "backends/rand_impl.hpp" #include "io/sepval.hpp" #include "util/index_into.hpp" #include "util/padded_alloc.hpp" @@ -27,11 +24,11 @@ #include "multicore_common.hpp" #include "shared_state.hpp" +#include "fvm.hpp" namespace arb { namespace multicore { -using util::make_range; using util::make_span; using util::ptr_by_key; using util::value_by_key; @@ -384,9 +381,8 @@ void shared_state::instantiate(arb::mechanism& m, util::padded_allocator<> pad(m.data_alignment()); - if (storage.find(id) != storage.end()) { - throw arbor_internal_error("Duplicate mechanism id in MC shared state."); - } + if (storage.count(id)) throw arbor_internal_error("Duplicate mechanism id in MC shared state."); + streams[id] = deliverable_event_stream{}; auto& store = storage[id]; auto width = pos_data.cv.size(); // Assign non-owning views onto shared state: @@ -539,5 +535,13 @@ void shared_state::instantiate(arb::mechanism& m, } } +void shared_state::init_events(const event_lane_subrange& lanes, + const std::vector& handles, + const std::vector& divs, + const timestep_range& dts) { + arb::multicore::event_stream::multi_event_stream(lanes, handles, divs, dts, streams); +} + + } // namespace multicore } // namespace arb diff --git a/arbor/backends/multicore/shared_state.hpp b/arbor/backends/multicore/shared_state.hpp index 3a4a97a9e7..99bbe37b51 100644 --- a/arbor/backends/multicore/shared_state.hpp +++ b/arbor/backends/multicore/shared_state.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -15,14 +14,12 @@ #include #include "fvm_layout.hpp" -#include "timestep_range.hpp" #include "util/padded_alloc.hpp" #include "util/rangeutil.hpp" #include "threading/threading.hpp" -#include "backends/event.hpp" #include "backends/common_types.hpp" #include "backends/rand_fwd.hpp" #include "backends/shared_state_base.hpp" @@ -103,8 +100,6 @@ struct mech_storage { std::vector gid_; std::vector idx_; cbprng::counter_type random_number_update_counter_ = 0u; - - deliverable_event_stream deliverable_events_; }; struct ARB_ARBOR_API istim_state { @@ -180,6 +175,7 @@ struct ARB_ARBOR_API shared_state: istim_state stim_data; std::unordered_map ion_data; std::unordered_map storage; + std::unordered_map streams; shared_state() = default; @@ -249,6 +245,11 @@ struct ARB_ARBOR_API shared_state: sample_time_host = util::range_pointer_view(sample_time); sample_value_host = util::range_pointer_view(sample_value); } + + void init_events(const event_lane_subrange& lanes, + const std::vector& handles, + const std::vector& divs, + const timestep_range& dts); }; // For debugging only: @@ -266,6 +267,7 @@ ARB_SERDES_ENABLE_EXT(multicore::shared_state, cbprng_seed, ion_data, storage, + streams, voltage, conductivity, time_since_spike, diff --git a/arbor/backends/shared_state_base.hpp b/arbor/backends/shared_state_base.hpp index 6a8a37ca4c..c20247d363 100644 --- a/arbor/backends/shared_state_base.hpp +++ b/arbor/backends/shared_state_base.hpp @@ -7,7 +7,8 @@ #include "backends/common_types.hpp" #include "fvm_layout.hpp" -#include "util/rangeutil.hpp" +#include "event_lane.hpp" +#include "timestep_range.hpp" namespace arb { @@ -26,18 +27,14 
@@ struct shared_state_base { d->time = d->time_to; } - void begin_epoch(const std::vector>>& staged_events_per_mech_id, + void begin_epoch(const event_lane_subrange& lanes, const std::vector>& samples, - const timestep_range& dts) { + const timestep_range& dts, + const std::vector& handles, + const std::vector& divs) { auto d = static_cast(this); // events - auto& storage = d->storage; - for (auto& [mech_id, store] : storage) { - if (mech_id < staged_events_per_mech_id.size() && staged_events_per_mech_id[mech_id].size()) - { - store.deliverable_events_.init(staged_events_per_mech_id[mech_id]); - } - } + d->init_events(lanes, handles, divs, dts); // samples auto n_samples = util::sum_by(samples, [] (const auto& s) {return s.size();}); if (d->sample_time.size() < n_samples) { @@ -91,18 +88,15 @@ struct shared_state_base { void mark_events() { auto d = static_cast(this); - auto& storage = d->storage; - for (auto& s : storage) { - s.second.deliverable_events_.mark(); - } + auto& streams = d->streams; + for (auto& stream: streams) stream.second.mark(); } void deliver_events(mechanism& m) { auto d = static_cast(this); - auto& storage = d->storage; - if (auto it = storage.find(m.mechanism_id()); it != storage.end()) { - auto& deliverable_events = it->second.deliverable_events_; - if (!deliverable_events.empty()) { + auto& streams = d->streams; + if (auto it = streams.find(m.mechanism_id()); it != streams.end()) { + if (auto& deliverable_events = it->second; !deliverable_events.empty()) { auto state = deliverable_events.marked_events(); m.deliver_events(state); } diff --git a/arbor/cable_cell_group.cpp b/arbor/cable_cell_group.cpp index efa27ea0e8..f04ac941d6 100644 --- a/arbor/cable_cell_group.cpp +++ b/arbor/cable_cell_group.cpp @@ -28,32 +28,16 @@ cable_cell_group::cable_cell_group(const std::vector& gids, fvm_lowered_cell_ptr lowered): gids_(gids), lowered_(std::move(lowered)) { - // Build lookup table for gid to local index. - for (auto i: util::count_along(gids_)) { - gid_index_map_[gids_[i]] = i; - } // Construct cell implementation, retrieving handles and maps. auto fvm_info = lowered_->initialize(gids_, rec); - for (auto [mech_id, n_targets] : fvm_info.num_targets_per_mech_id) { - if (n_targets > 0u && mech_id >= staged_events_per_mech_id_.size()) { - staged_events_per_mech_id_.resize(mech_id+1); - } - } - // Propagate source and target ranges to the simulator object cg_sources = std::move(fvm_info.source_data); cg_targets = std::move(fvm_info.target_data); - // Store consistent data from fvm_lowered_cell - target_handles_ = std::move(fvm_info.target_handles); probe_map_ = std::move(fvm_info.probe_map); - // Create lookup structure for target ids. - util::make_partition(target_handle_divisions_, - util::transform_view(gids_, [&](cell_gid_type i) { return fvm_info.num_targets[i]; })); - // Create a list of the global identifiers for the spike sources for (auto source_gid: gids_) { for (cell_lid_type lid = 0; lidtime(); // Bin and collate deliverable events from event lanes. - - PE(advance:eventsetup:clear); // Split epoch into equally sized timesteps (last timestep is chosen to match end of epoch) timesteps_.reset(ep, dt); - for (auto& vv : staged_events_per_mech_id_) { - vv.resize(timesteps_.size()); - for (auto& v : vv) { - v.clear(); - } - } - sample_events_.resize(timesteps_.size()); - for (auto& v : sample_events_) { - v.clear(); - } - PL(); - // Skip event handling if nothing to deliver. 
- PE(advance:eventsetup:push); - if (util::sum_by(event_lanes, [] (const auto& l) {return l.size();})) { - auto lid = 0; - for (auto& lane: event_lanes) { - arb_size_type timestep_index = 0; - for (auto e: lane) { - // Events coinciding with epoch's upper boundary belong to next epoch - const auto time = e.time; - if (time >= ep.t1) break; - while(time >= timesteps_[timestep_index].t_end()) { - ++timestep_index; - } - arb_assert(timestep_index < timesteps_.size()); - const auto offset = target_handle_divisions_[lid]+e.target; - const auto h = target_handles_[offset]; - staged_events_per_mech_id_[h.mech_id][timestep_index].emplace_back(e.time, h, e.weight); - } - ++lid; - } - } + PE(advance:samplesetup:clear); + sample_events_.resize(timesteps_.size()); + for (auto& v: sample_events_) v.clear(); PL(); // Create sample events and delivery information. @@ -472,7 +425,7 @@ void cable_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange PL(); // Run integration and collect samples, spikes. - auto result = lowered_->integrate(timesteps_, staged_events_per_mech_id_, sample_events_); + auto result = lowered_->integrate(timesteps_, event_lanes, sample_events_); // For each sampler callback registered in `call_info`, construct the // vector of sample entries from the lowered cell sample times and values diff --git a/arbor/cable_cell_group.hpp b/arbor/cable_cell_group.hpp index 614b563a38..7403d82ef4 100644 --- a/arbor/cable_cell_group.hpp +++ b/arbor/cable_cell_group.hpp @@ -1,10 +1,6 @@ #pragma once -#include -#include -#include #include -#include #include #include @@ -26,14 +22,12 @@ namespace arb { struct ARB_ARBOR_API cable_cell_group: public cell_group { cable_cell_group() = default; cable_cell_group(const std::vector& gids, - const recipe& rec, - cell_label_range& cg_sources, - cell_label_range& cg_targets, - fvm_lowered_cell_ptr lowered); + const recipe& rec, + cell_label_range& cg_sources, + cell_label_range& cg_targets, + fvm_lowered_cell_ptr lowered); - cell_kind get_cell_kind() const override { - return cell_kind::cable; - } + cell_kind get_cell_kind() const override { return cell_kind::cable; } void reset() override; @@ -60,9 +54,6 @@ struct ARB_ARBOR_API cable_cell_group: public cell_group { // List of the gids of the cells in the group. std::vector gids_; - // Hash table for converting gid to local index - std::unordered_map gid_index_map_; - // The lowered cell state (e.g. FVM) of the cell. fvm_lowered_cell_ptr lowered_; @@ -75,15 +66,9 @@ struct ARB_ARBOR_API cable_cell_group: public cell_group { // Range of timesteps within current epoch timestep_range timesteps_; - // List of events to deliver per mechanism id - std::vector>> staged_events_per_mech_id_; - // List of samples to be taken std::vector> sample_events_; - // Handles for accessing lowered cell. - std::vector target_handles_; - // Maps probe ids to probe handles (from lowered cell) and tags (from probe descriptions). probe_association_map probe_map_; @@ -92,9 +77,6 @@ struct ARB_ARBOR_API cable_cell_group: public cell_group { // Mutex for thread-safe access to sampler associations. std::mutex sampler_mex_; - - // Lookup table for target ids -> local target handle indices. 
- std::vector target_handle_divisions_; }; } // namespace arb diff --git a/arbor/cell_group.hpp b/arbor/cell_group.hpp index d372f11877..f7689dd973 100644 --- a/arbor/cell_group.hpp +++ b/arbor/cell_group.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -8,11 +7,10 @@ #include #include #include -#include #include #include "epoch.hpp" -#include "util/rangeutil.hpp" +#include "event_lane.hpp" // The specialized cell_group constructors are expected to accept at least: // - The gid vector of the cells belonging to the cell_group. @@ -22,8 +20,6 @@ // ranges are needed to map (gid, label) pairs to their corresponding lid sets. namespace arb { -using event_lane_subrange = util::subrange_view_type>; - class cell_group { public: virtual ~cell_group() = default; diff --git a/arbor/event_lane.hpp b/arbor/event_lane.hpp new file mode 100644 index 0000000000..4aa8c9bb1d --- /dev/null +++ b/arbor/event_lane.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include + +#include "util/rangeutil.hpp" + +namespace arb { + +using event_lane_subrange = util::subrange_view_type>; + +} // namespace arb diff --git a/arbor/fvm_lowered_cell.hpp b/arbor/fvm_lowered_cell.hpp index e324661c1d..896380c019 100644 --- a/arbor/fvm_lowered_cell.hpp +++ b/arbor/fvm_lowered_cell.hpp @@ -2,7 +2,6 @@ #include #include -#include #include #include #include @@ -17,19 +16,14 @@ #include #include #include +#include #include "backends/event.hpp" #include "backends/common_types.hpp" -#include "backends/threshold_crossing.hpp" #include "execution_context.hpp" -#include "sampler_map.hpp" +#include "event_lane.hpp" #include "timestep_range.hpp" -#include "util/maputil.hpp" -#include "util/meta.hpp" #include "util/range.hpp" -#include "util/rangeutil.hpp" -#include "util/strprintf.hpp" -#include "util/transform.hpp" namespace arb { @@ -217,7 +211,6 @@ struct probe_association_map { struct fvm_initialization_data { // Handles for accessing lowered cell. - std::vector target_handles; std::unordered_map num_targets_per_mech_id; // Maps probe ids to probe handles and tags. 
@@ -242,10 +235,9 @@ struct fvm_lowered_cell { const std::vector& gids, const recipe& rec) = 0; - virtual fvm_integration_result integrate( - const timestep_range& dts, - const std::vector>>& staged_events_per_mech_id, - const std::vector>& staged_samples) = 0; + virtual fvm_integration_result integrate(const timestep_range& dts, + const event_lane_subrange& event_lanes, + const std::vector>& staged_samples) = 0; virtual arb_value_type time() const = 0; diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index 316300abbb..730471d109 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -20,6 +20,7 @@ #include "execution_context.hpp" #include "fvm_layout.hpp" + #include "fvm_lowered_cell.hpp" #include "label_resolution.hpp" #include "profile/profiler_macro.hpp" @@ -30,10 +31,8 @@ #include "util/transform.hpp" namespace arb { - template -class fvm_lowered_cell_impl: public fvm_lowered_cell { -public: +struct fvm_lowered_cell_impl: public fvm_lowered_cell { using backend = Backend; using value_type = arb_value_type; using index_type = arb_index_type; @@ -50,10 +49,9 @@ class fvm_lowered_cell_impl: public fvm_lowered_cell { const std::vector& gids, const recipe& rec) override; - fvm_integration_result integrate( - const timestep_range& dts, - const std::vector>>& staged_events_per_mech_id, - const std::vector>& staged_samples) override; + fvm_integration_result integrate(const timestep_range& dts, + const event_lane_subrange& event_lanes, + const std::vector>& staged_samples) override; value_type time() const override { return state_->time; } @@ -67,7 +65,6 @@ class fvm_lowered_cell_impl: public fvm_lowered_cell { void t_serialize(serializer& ser, const std::string& k) const override { serialize(ser, k, *this); } void t_deserialize(serializer& ser, const std::string& k) override { deserialize(ser, k, *this); } -private: // Host or GPU-side back-end dependent storage. using array = typename backend::array; using shared_state = typename backend::shared_state; @@ -80,6 +77,11 @@ class fvm_lowered_cell_impl: public fvm_lowered_cell { std::vector revpot_mechanisms_; std::vector voltage_mechanisms_; + // Handles for accessing event targets. + std::vector target_handles_; + // Lookup table for target ids -> local target handle indices. + std::vector target_handle_divisions_; + // Optional non-physical voltage check threshold std::optional check_voltage_mV_; @@ -164,18 +166,16 @@ void fvm_lowered_cell_impl::reset() { } template -fvm_integration_result fvm_lowered_cell_impl::integrate( - const timestep_range& dts, - const std::vector>>& staged_events_per_mech_id, - const std::vector>& staged_samples) -{ +fvm_integration_result fvm_lowered_cell_impl::integrate(const timestep_range& dts, + const event_lane_subrange& event_lanes, + const std::vector>& staged_samples) { arb_assert(state_->time == dts.t_begin()); set_gpu(); // Integration setup PE(advance:integrate:setup); // Push samples and events down to the state and reset the spike thresholds. - state_->begin_epoch(staged_events_per_mech_id, staged_samples, dts); + state_->begin_epoch(event_lanes, staged_samples, dts, target_handles_, target_handle_divisions_); PL(); // loop over timesteps @@ -477,10 +477,9 @@ fvm_lowered_cell_impl::initialize(const std::vector& gid data_alignment? data_alignment: 1u, seed_); - fvm_info.target_handles.resize(mech_data.n_target); - // Keep track of mechanisms by name for probe lookup. 
std::unordered_map mechptr_by_name; + target_handles_.resize(mech_data.n_target); unsigned mech_id = 0; for (const auto& [name, config]: mech_data.mechanisms) { @@ -516,11 +515,11 @@ fvm_lowered_cell_impl::initialize(const std::vector& gid target_handle handle(mech_id, i); if (config.multiplicity.empty()) { - fvm_info.target_handles[config.target[i]] = handle; + target_handles_[config.target[i]] = handle; } else { for (auto j: make_span(multiplicity_part[i])) { - fvm_info.target_handles[config.target[j]] = handle; + target_handles_[config.target[j]] = handle; } } } @@ -579,8 +578,14 @@ fvm_lowered_cell_impl::initialize(const std::vector& gid } } - add_probes(gids, cells, rec, D, mechptr_by_name, mech_data, fvm_info.target_handles, fvm_info.probe_map); + add_probes(gids, cells, rec, D, mechptr_by_name, mech_data, target_handles_, fvm_info.probe_map); + + // Create lookup structure for target ids. + util::make_partition(target_handle_divisions_, + util::transform_view(gids, + [&](cell_gid_type i) { return fvm_info.num_targets[i]; })); + reset(); return fvm_info; } diff --git a/test/unit/test_fvm_layout.cpp b/test/unit/test_fvm_layout.cpp index 956bc965d6..68f392b12a 100644 --- a/test/unit/test_fvm_layout.cpp +++ b/test/unit/test_fvm_layout.cpp @@ -40,7 +40,7 @@ using backend = arb::multicore::backend; using fvm_cell = arb::fvm_lowered_cell_impl; // instantiate template class -template class arb::fvm_lowered_cell_impl; +template struct arb::fvm_lowered_cell_impl; namespace U = arb::units; diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index ef83d361fa..383a1ffd5c 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -237,7 +237,7 @@ TEST(fvm_lowered, target_handles) { EXPECT_EQ(cells[0].morphology().num_branches(), 1u); EXPECT_EQ(cells[1].morphology().num_branches(), 3u); - auto test_target_handles = [&](fvm_cell& cell, const std::vector& targets) { + auto test_target_handles = [&](fvm_cell& cell) { mechanism* expsyn = find_mechanism(cell, "expsyn"); ASSERT_TRUE(expsyn); mechanism* exp2syn = find_mechanism(cell, "exp2syn"); @@ -246,28 +246,28 @@ TEST(fvm_lowered, target_handles) { unsigned expsyn_id = expsyn->mechanism_id(); unsigned exp2syn_id = exp2syn->mechanism_id(); - EXPECT_EQ(4u, targets.size()); + EXPECT_EQ(4u, cell.target_handles_.size()); - EXPECT_EQ(expsyn_id, targets[0].mech_id); - EXPECT_EQ(1u, targets[0].mech_index); + EXPECT_EQ(expsyn_id, cell.target_handles_[0].mech_id); + EXPECT_EQ(1u, cell.target_handles_[0].mech_index); - EXPECT_EQ(expsyn_id, targets[1].mech_id); - EXPECT_EQ(0u, targets[1].mech_index); + EXPECT_EQ(expsyn_id, cell.target_handles_[1].mech_id); + EXPECT_EQ(0u, cell.target_handles_[1].mech_index); - EXPECT_EQ(exp2syn_id, targets[2].mech_id); - EXPECT_EQ(0u, targets[2].mech_index); + EXPECT_EQ(exp2syn_id, cell.target_handles_[2].mech_id); + EXPECT_EQ(0u, cell.target_handles_[2].mech_index); - EXPECT_EQ(expsyn_id, targets[3].mech_id); - EXPECT_EQ(2u, targets[3].mech_index); + EXPECT_EQ(expsyn_id, cell.target_handles_[3].mech_id); + EXPECT_EQ(2u, cell.target_handles_[3].mech_index); }; fvm_cell fvcell0(*context); auto fvm_info0 = fvcell0.initialize({0, 1}, cable1d_recipe(cells, true)); - test_target_handles(fvcell0, fvm_info0.target_handles); + test_target_handles(fvcell0); fvm_cell fvcell1(*context); auto fvm_info1 = fvcell1.initialize({0, 1}, cable1d_recipe(cells, false)); - test_target_handles(fvcell1, fvm_info1.target_handles); + test_target_handles(fvcell1); } @@ -699,7 +699,9 @@ TEST(fvm_lowered, 
point_ionic_current) { // Only one target, corresponding to our point process on soma. double ica_nA = 12.3; - deliverable_event ev = {0.04, target_handle{0, 0}, (float)ica_nA}; + std::vector events{{{0, 0.04, (float)ica_nA}}}; + auto lanes = util::subrange_view(events, 0, events.size()); + auto& state = *(fvcell.*private_state_ptr).get(); auto& ion = state.ion_data.at("ca"s); @@ -709,7 +711,7 @@ TEST(fvm_lowered, point_ionic_current) { // Ionic current should be ica_nA/soma_area after integrating past event time. const double time = 0.5; // [ms] - (void)fvcell.integrate({time, 0.01}, {{{},{},{},{},{ev}}}, {}); + (void)fvcell.integrate({time, 0.01}, lanes, {}); double expected_iX = ica_nA*1e-9/soma_area_m2; EXPECT_FLOAT_EQ(expected_iX, ion.iX_[0]); diff --git a/test/unit/test_probe.cpp b/test/unit/test_probe.cpp index 6b82baecc7..2df8fce69b 100644 --- a/test/unit/test_probe.cpp +++ b/test/unit/test_probe.cpp @@ -281,7 +281,7 @@ void run_expsyn_g_probe_test(context ctx) { fvm_cell lcell(*ctx); auto fvm_info = lcell.initialize({0}, rec); const auto& probe_map = fvm_info.probe_map; - const auto& targets = fvm_info.target_handles; + const auto& targets = lcell.target_handles_; EXPECT_EQ(2u, rec.get_probes(0).size()); EXPECT_EQ(2u, probe_map.size()); @@ -309,14 +309,12 @@ void run_expsyn_g_probe_test(context ctx) { // and another at 2ms to second, weight 1. arb_assert(targets[0].mech_id == targets[1].mech_id); - std::vector>> events(targets[0].mech_id+1); const double tfinal = 3.0; const double dt = 0.001; const timestep_range dts{tfinal, dt}; - events[targets[0].mech_id].resize(dts.size()); - events[targets[0].mech_id][dts.find(1.0)-dts.begin()].push_back(deliverable_event{1.0, targets[0], 0.5}); - events[targets[0].mech_id][dts.find(2.0)-dts.begin()].push_back(deliverable_event{2.0, targets[1], 1.0}); - lcell.integrate(dts, events, {}); + std::vector events{{{0, 1.0, 0.5}, {1, 2.0, 1.0}}}; + auto lanes = util::subrange_view(events, 0, events.size()); + lcell.integrate(dts, lanes, {}); arb_value_type g0 = deref(p0); arb_value_type g1 = deref(p1); @@ -392,23 +390,20 @@ void run_expsyn_g_cell_probe_test(context ctx) { fvm_cell lcell(*ctx); auto fvm_info = lcell.initialize({0, 1}, rec); const auto& probe_map = fvm_info.probe_map; - const auto& targets = fvm_info.target_handles; + const auto& targets = lcell.target_handles_; // Send an event to each expsyn synapse with a weight = target+100*cell_gid, and // integrate for a tiny time step. - - std::vector>> events(2, std::vector>(1)); - for (unsigned i: {0u, 1u}) { - // Cells have the same number of targets, so the offset for cell 1 is exactly... - cell_local_size_type cell_offset = i==0? 0: targets.size()/2; - + std::vector events; + events.resize(targets.size()); + for (unsigned gid: {0u, 1u}) { for (auto target_id: util::keys(expsyn_target_loc_map)) { - auto h = targets.at(target_id+cell_offset); - deliverable_event ev{0., h, float(target_id+100*i)}; - events[h.mech_id][0].push_back(ev); + events[gid].emplace_back(target_id, 0., float(target_id+100*gid)); } } - (void)lcell.integrate({1e-5, 1e-5}, events, {}); + + auto lanes = util::subrange_view(events, 0, events.size()); + (void)lcell.integrate({1e-5, 1e-5}, lanes, {}); // Independently get cv geometry to compute CV indices. 
diff --git a/test/unit/test_synapses.cpp b/test/unit/test_synapses.cpp index 8e56151faa..ffe9ca40a5 100644 --- a/test/unit/test_synapses.cpp +++ b/test/unit/test_synapses.cpp @@ -149,10 +149,11 @@ TEST(synapses, syn_basic_state) { // Deliver two events (at time 0), one each to expsyn synapses 1 and 3 // and exp2syn synapses 0 and 2. - - state.begin_epoch({{{{0., {0, 1}, 3.14f}, {0., {0, 3}, 1.41f}}}, // events for mech_id == 0 - {{{0., {1, 0}, 2.71f}, {0., {1, 2}, 0.07f}}}}, // events for mech_id == 1 - {}, dts); + std::vector events{{{0, 0.0, 3.14f}, {1, 0.0, 1.41f}, {2, 0.0, 2.71f}, {3, 0.0, 0.07f}}}; + auto lanes = event_lane_subrange(events.begin(), events.end()); + std::vector handles{{0, 1}, {0, 3}, {1, 0}, {1, 2}}; + std::vector divs{0}; + state.begin_epoch(lanes, {}, dts, handles, divs); state.mark_events(); state.deliver_events(*expsyn); @@ -160,8 +161,8 @@ TEST(synapses, syn_basic_state) { using fvec = std::vector; - EXPECT_TRUE(testing::seq_almost_eq( - fvec({0, 3.14f, 0, 1.41f}), mechanism_field(expsyn, "g"))); + EXPECT_TRUE(testing::seq_almost_eq(fvec({0, 3.14f, 0, 1.41f}), + mechanism_field(expsyn, "g"))); double factor = mechanism_field(exp2syn, "factor")[0]; EXPECT_TRUE(factor>1.);
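---

Editorial note: the sketch referenced in the description above. It is a minimal, self-contained model of the new staging scheme and is not part of the patch: `SpikeEvent`, `Handle`, `Stream`, and `stage_events` are hypothetical stand-ins for Arbor's `spike_event`, `target_handle`, `event_stream`, and `event_stream_base::multi_event_stream`, and the explicit `stable_sort` below stands in for the per-lane bucket counting plus `util::make_partition` used in the real code.

```cpp
// Simplified model of the event-staging scheme introduced by this PR:
// sort lane events directly into one flat buffer per mechanism, partitioned by time step.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

struct SpikeEvent { std::size_t target; double time; float weight; };  // one entry of an event lane
struct Handle     { unsigned mech_id; unsigned mech_index; };          // target -> (mechanism, instance)

// One stream per mechanism id: a flat event buffer plus a partition of offsets;
// time step k covers ev_data_[ev_spans_[k] .. ev_spans_[k+1]).
struct Stream {
    std::vector<std::pair<unsigned, float>> ev_data_;   // (mech_index, weight)
    std::vector<std::size_t> ev_spans_ = {0};
};

std::unordered_map<unsigned, Stream>
stage_events(const std::vector<std::vector<SpikeEvent>>& lanes,  // one lane per cell, sorted by time
             const std::vector<Handle>& handles,                 // global target handle table
             const std::vector<std::size_t>& divs,               // per-cell offsets into `handles`
             const std::vector<double>& step_ends) {             // upper bounds of the dt buckets
    const std::size_t n_steps = step_ends.size();

    // Sort lane events directly into per-mechanism staging vectors, skipping the old
    // vector-of-vectors-of-vectors intermediate keyed by (mech id, time step).
    std::unordered_map<unsigned, std::vector<std::tuple<std::size_t, unsigned, float>>> staged;
    for (std::size_t cell = 0; cell < lanes.size(); ++cell) {
        std::size_t step = 0;
        for (const auto& ev: lanes[cell]) {
            while (step < n_steps && ev.time >= step_ends[step]) ++step;
            if (step >= n_steps) break;                           // event belongs to the next epoch
            const Handle& h = handles[divs[cell] + ev.target];
            staged[h.mech_id].push_back({step, h.mech_index, ev.weight});
        }
    }

    // Flatten: group each mechanism's events by time step and record only span boundaries
    // (one offset per dt) instead of a (begin, end) range per dt.
    std::unordered_map<unsigned, Stream> streams;
    for (auto& [id, evs]: staged) {
        std::stable_sort(evs.begin(), evs.end(),
                         [](const auto& a, const auto& b) { return std::get<0>(a) < std::get<0>(b); });
        Stream& s = streams[id];
        std::vector<std::size_t> counts(n_steps, 0);
        for (const auto& [step, idx, w]: evs) {
            ++counts[step];
            s.ev_data_.push_back({idx, w});
        }
        for (std::size_t n: counts) s.ev_spans_.push_back(s.ev_spans_.back() + n);  // exclusive scan
    }
    return streams;
}

int main() {
    // One cell with three targets on two mechanisms; two dt buckets ending at 0.5 and 1.0 ms.
    std::vector<Handle> handles = {{0, 0}, {1, 0}, {0, 1}};
    std::vector<std::size_t> divs = {0};
    std::vector<std::vector<SpikeEvent>> lanes = {{{0, 0.1, 1.f}, {2, 0.3, 2.f}, {1, 0.7, 3.f}}};

    auto streams = stage_events(lanes, handles, divs, {0.5, 1.0});
    for (const auto& [id, s]: streams) {
        for (std::size_t k = 0; k + 1 < s.ev_spans_.size(); ++k) {
            std::printf("mech %u, step %zu:", id, k);
            for (std::size_t i = s.ev_spans_[k]; i < s.ev_spans_[k + 1]; ++i)
                std::printf(" (%u, %.2f)", s.ev_data_[i].first, s.ev_data_[i].second);
            std::printf("\n");
        }
    }
}
```

During integration the consumer walks this partition one bucket at a time; in the patch that is what `mark()` and `marked_events()` do, exposing the range `[base_ptr_ + ev_spans_[index_-1], base_ptr_ + ev_spans_[index_])` for the current step. The GPU back end additionally copies each host span to a device buffer with `memory::copy_async` and, where the mechanism uses an event index, sorts the span by `mech_index` first.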