From 218d0b3cf10c38d8e5c08941bdcdce705aecdfe8 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Thu, 21 Nov 2024 07:56:41 +0100 Subject: [PATCH] Remove introsort, parallelise merging across mech_ids. --- arbor/backends/event_stream_base.hpp | 94 ++++++++++++++++++++-------- arbor/backends/shared_state_base.hpp | 5 +- arbor/fvm_lowered_cell.hpp | 7 +-- arbor/fvm_lowered_cell_impl.hpp | 8 +-- 4 files changed, 77 insertions(+), 37 deletions(-) diff --git a/arbor/backends/event_stream_base.hpp b/arbor/backends/event_stream_base.hpp index a81254a78..69bed4700 100644 --- a/arbor/backends/event_stream_base.hpp +++ b/arbor/backends/event_stream_base.hpp @@ -7,6 +7,7 @@ #include "backends/event.hpp" #include "backends/event_stream_state.hpp" #include "event_lane.hpp" +#include "threading/threading.hpp" #include "timestep_range.hpp" #include "util/partition.hpp" @@ -22,9 +23,11 @@ struct event_stream_base { protected: // members std::vector ev_data_; std::vector ev_spans_ = {0}; + std::vector lane_spans_; std::size_t index_ = 0; event_data_type* base_ptr_ = nullptr; + public: event_stream_base() = default; @@ -63,31 +66,61 @@ struct event_stream_base { virtual void init() = 0; }; -struct spike_event_stream_base : event_stream_base { +struct spike_event_stream_base: event_stream_base { + // Take in one event lane per cell `gid` and reorganise into one stream per + // synapse `mech_id`. + // + // - Due to the cell group coalescing multiple cells and their synapses into + // one object, one `mech_id` can touch multiple lanes / `gid`s. 
+ // - Inversely, two `mech_id`s can cover different, but overlapping sets of `gid`s + // - Multiple `mech_id`s can receive events from the same source + // + // Pre: + // - Events in `lanes[ix]` forall ix + // * are sorted by time + // * `ix` maps to exactly one cell in the local cell group + // - `divs` partitions `handles` such that the target handles for cell `ix` + // are located in `handles[divs[ix]..divs[ix + 1]]` + // - `handles` records `(mech_id, index)` of a target s.t. `index` is the instance + // within the set identified by `mech_id`, e.g. a single synapse placed on a multi- + // location locset (plus the merging across cells by groups) + // Post: + // - streams[mech_id] contains a list of all events for synapse `mech_id` s.t. + // * the list is sorted by (time_step, lid, time) + // * the list is partitioned by `time_step` via `ev_spans` template friend void initialize(const event_lane_subrange& lanes, const std::vector& handles, const std::vector& divs, const timestep_range& steps, - std::unordered_map& streams) { + std::unordered_map& streams, + task_system_handle ts) { arb_assert(lanes.size() < divs.size()); // reset streams and allocate sufficient space for temporaries auto n_steps = steps.size(); - for (auto& [k, v]: streams) { - v.clear(); - v.spike_counter_.clear(); - v.spike_counter_.resize(steps.size(), 0); - v.spikes_.clear(); + for (auto& [id, stream]: streams) { + stream.clear(); + stream.spike_counter_.clear(); + stream.spike_counter_.resize(steps.size(), 0); + stream.spikes_.clear(); // ev_data_ has been cleared during v.clear(), so we use its capacity - v.spikes_.reserve(v.ev_data_.capacity()); + stream.spikes_.reserve(stream.ev_data_.capacity()); + // record sizes of streams for later merging + // + // The idea here is that this records the division points `lane_spans_` where + // `stream` was updated by the lane index `cell`. As events within one lane are + // sorted, we know that events between two division points are sorted.
+ // Then, we can use `std::inplace_merge` over `sort` for a small but noticeable + // speed-up. + stream.lane_spans_.resize(lanes.size() + 1); + for (auto& ix: stream.lane_spans_) ix = stream.spikes_.size(); } // loop over lanes: group events by mechanism and sort them by time auto cell = 0; for (const auto& lane: lanes) { auto div = divs[cell]; - ++cell; arb_size_type step = 0; for (const auto& evt: lane) { auto time = evt.time; @@ -100,28 +133,39 @@ struct spike_event_stream_base : event_stream_base { const auto& handle = handles[div + target]; auto& stream = streams[handle.mech_id]; stream.spikes_.push_back(spike_data{step, handle.mech_index, time, weight}); - // insertion sort with last element as pivot - // ordering: first w.r.t. step, within a step: mech_index, within a mech_index: time - auto first = stream.spikes_.begin(); - auto last = stream.spikes_.end(); - auto pivot = std::prev(last, 1); - std::rotate(std::upper_bound(first, pivot, *pivot), pivot, last); - // increment count in current time interval stream.spike_counter_[step]++; } + // record current sizes here. putting this into the above loop is slower.
significantly + for (auto& [id, stream]: streams) stream.lane_spans_[cell + 1] = stream.spikes_.size(); + ++cell; } + // parallelise over streams + auto tg = threading::task_group(ts.get()); for (auto& [id, stream]: streams) { - // copy temporary deliverable_events into stream's ev_data_ - stream.ev_data_.reserve(stream.spikes_.size()); - std::transform(stream.spikes_.begin(), stream.spikes_.end(), std::back_inserter(stream.ev_data_), - [](auto const& e) noexcept -> arb_deliverable_event_data { - return {e.mech_index, e.weight}; }); - // scan over spike_counter_ and written to ev_spans_ - util::make_partition(stream.ev_spans_, stream.spike_counter_); - // delegate to derived class init: static cast necessary to access protected init() - static_cast(stream).init(); + tg.run([&stream]() { + // scan over spike_counter_ + util::make_partition(stream.ev_spans_, stream.spike_counter_); + // leverage our earlier partitioning to merge the partitions + // theoretically, this could be parallelised, too, practically it didn't pay off + auto& part = stream.lane_spans_; + for (size_t ix = 0; ix < part.size() - 1; ++ix) { + std::inplace_merge(stream.spikes_.begin(), + stream.spikes_.begin() + part[ix], + stream.spikes_.begin() + part[ix + 1]); + } + // Further optimisation: merge(!) merging, transforming, and appending into one + // call. 
+ // copy temporary deliverable_events into stream's ev_data_ + stream.ev_data_.reserve(stream.spikes_.size()); + std::transform(stream.spikes_.begin(), stream.spikes_.end(), + std::back_inserter(stream.ev_data_), + [](auto const& e) noexcept -> arb_deliverable_event_data { return {e.mech_index, e.weight}; }); + // delegate to derived class init: static cast necessary to access protected init() + static_cast(stream).init(); + }); } + tg.wait(); } protected: // members diff --git a/arbor/backends/shared_state_base.hpp b/arbor/backends/shared_state_base.hpp index 0c1742a0d..b463e4ad0 100644 --- a/arbor/backends/shared_state_base.hpp +++ b/arbor/backends/shared_state_base.hpp @@ -32,10 +32,11 @@ struct shared_state_base { const std::vector>& samples, const timestep_range& dts, const std::vector& handles, - const std::vector& divs) { + const std::vector& divs, + task_system_handle ts) { auto d = static_cast(this); // events - initialize(lanes, handles, divs, dts, d->streams); + initialize(lanes, handles, divs, dts, d->streams, ts); // samples auto n_samples = util::sum_by(samples, [] (const auto& s) {return s.size();}); if (d->sample_time.size() < n_samples) { diff --git a/arbor/fvm_lowered_cell.hpp b/arbor/fvm_lowered_cell.hpp index 896380c01..709253f77 100644 --- a/arbor/fvm_lowered_cell.hpp +++ b/arbor/fvm_lowered_cell.hpp @@ -231,9 +231,7 @@ struct fvm_initialization_data { struct fvm_lowered_cell { virtual void reset() = 0; - virtual fvm_initialization_data initialize( - const std::vector& gids, - const recipe& rec) = 0; + virtual fvm_initialization_data initialize(const std::vector& gids, const recipe& rec) = 0; virtual fvm_integration_result integrate(const timestep_range& dts, const event_lane_subrange& event_lanes, @@ -249,8 +247,7 @@ struct fvm_lowered_cell { using fvm_lowered_cell_ptr = std::unique_ptr; -ARB_ARBOR_API fvm_lowered_cell_ptr make_fvm_lowered_cell(backend_kind p, const execution_context& ctx, - std::uint64_t seed = 0); +ARB_ARBOR_API 
fvm_lowered_cell_ptr make_fvm_lowered_cell(backend_kind p, const execution_context& ctx, std::uint64_t seed = 0); inline void serialize(serializer& s, const std::string& k, const fvm_lowered_cell& v) { v.t_serialize(s, k); } diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index adcbab1a9..d2a5375e7 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -42,13 +42,11 @@ struct fvm_lowered_cell_impl: public fvm_lowered_cell { fvm_lowered_cell_impl(execution_context ctx, arb_seed_type seed = 0): context_(ctx), seed_{seed} - {}; + {} void reset() override; - fvm_initialization_data initialize( - const std::vector& gids, - const recipe& rec) override; + fvm_initialization_data initialize(const std::vector& gids, const recipe& rec) override; fvm_integration_result integrate(const timestep_range& dts, const event_lane_subrange& event_lanes, @@ -176,7 +174,7 @@ fvm_integration_result fvm_lowered_cell_impl::integrate(const timestep_ // Integration setup PE(advance:integrate:setup); // Push samples and events down to the state and reset the spike thresholds. - state_->begin_epoch(event_lanes, staged_samples, dts, target_handles_, target_handle_divisions_); + state_->begin_epoch(event_lanes, staged_samples, dts, target_handles_, target_handle_divisions_, context_.thread_pool); PL(); // loop over timesteps