# 📊 Improve spike delivery (#2222)
# Introduction 

Issue #502 claims that under certain circumstances spike walking can
become a bottleneck, with L2 cache misses posited as the root cause.
The setup required to provoke this is pretty extreme, though.

# Baseline
Parameters were selected to show the problem while remaining feasible to run on a laptop.

Input
```json
{
    "name": "test",
    "num-cells": 256,
    "duration": 400,
    "min-delay": 10,
    "fan-in": 10000,
    "realtime-ratio": 0.01,
    "spike-frequency": 40,
    "threads": 4,
    "ranks": 8000
}
```

Timing
```
❯ hyperfine 'bin/drybench ~/src/arbor/example/drybench/params.json'
Benchmark 1: bin/drybench ~/src/arbor/example/drybench/params.json
  Time (mean ± σ):      4.334 s ±  0.146 s    [User: 6.026 s, System: 0.418 s]
  Range (min … max):    4.148 s …  4.667 s    10 runs
```

# Changes

- Store the connection list as a structure of arrays instead of an array
  of structures. This reduces cache misses during the binary search for
  the correct source.
- Use `lower_bound` instead of `equal_range`. This removes one of the two
  binary searches, which are cache-unfriendly.
- Treat all spikes from the same source in one go. This keeps all values
  around instead of discarding and re-acquiring them.
- Swap the member order in `spike_event`, reducing its size from 24 to 16
  bytes; see the sketch below.
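
Why the member swap pays off: `time_type` is an 8-byte `double`, so the old
`{target, time, weight}` order forces four bytes of padding after `target`
and four more at the end. A minimal sketch of the two layouts, with
stand-in typedefs for Arbor's `cell_lid_type` and `time_type`:
```cpp
#include <cstdint>
#include <cstdio>

using cell_lid_type = std::uint32_t; // stand-in for Arbor's 32-bit lid
using time_type     = double;        // stand-in for Arbor's time type

struct spike_event_old {  // 4 + pad(4) + 8 + 4 + pad(4) = 24 bytes
    cell_lid_type target;
    time_type time;
    float weight;
};

struct spike_event_new {  // 4 + 4 + 8 = 16 bytes, no padding
    cell_lid_type target;
    float weight;
    time_type time;
};

int main() {
    std::printf("old: %zu new: %zu\n",
                sizeof(spike_event_old), sizeof(spike_event_new)); // old: 24 new: 16
    return 0;
}
```
At 16 bytes, four events fit in a 64-byte cache line; at 24 bytes, only
two fit entirely.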

# Outcome

We get a minor reduction in runtime on my local machine:
```
❯ hyperfine 'bin/drybench ../example/drybench/params.json'
Benchmark 1: bin/drybench ../example/drybench/params.json
  Time (mean ± σ):      4.225 s ±  0.167 s    [User: 5.939 s, System: 0.397 s]
  Range (min … max):    4.064 s …  4.632 s    10 runs
```

4.064 s vs. 4.148 s (comparing the minima), but this is still within one $\sigma$.

# Routes not taken

## Using a faster `sort`
Tried `pdqsort`, but sorting isn't a real bottleneck and/or `pdq` doesn't
improve on our problem.

## Using a faster `lower_bound`
As with `sort`, no variation of `lower_bound` improves matters measurably.
We can conclude that enqueuing is bound by the actual pushing of events
into the cells' queues. As noted in #502, L2 cache misses are a probable
root cause.

## Building a temporary buffer of events
Tried to create a scratch space to dump all events into, keyed by their
queue index; to avoid allocations, the scratch buffer is kept around
between runs. Then, we sort this by index and build the queues in one go.
This proved a significant slow-down.
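
A rough reconstruction of the rejected design, with hypothetical stand-in
types; the extra per-event pair and the second sort pass are plausible
reasons it lost out against appending directly:
```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

struct pse { unsigned target; double time; float weight; }; // stand-in for spike_event

// Dump (queue-index, event) pairs into one flat scratch vector, sort by
// queue index, then split into the per-cell queues in a single pass.
void build_queues(std::vector<std::pair<std::size_t, pse>>& scratch,
                  std::vector<std::vector<pse>>& queues) {
    std::sort(scratch.begin(), scratch.end(),
              [](const auto& a, const auto& b) { return a.first < b.first; });
    for (const auto& [idx, ev]: scratch) queues[idx].push_back(ev);
    scratch.clear(); // clear() keeps capacity, so later epochs reuse the allocation
}
```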

## SoA-Splitting Spikes
Doesn't improve our main bottleneck, and `spike` is already pretty small
(= not much waste on a cache line).

## Multithreading the appends
That would only worsen the number of cache misses, as L2 is shared between
threads. Also, if cache misses are our problem, this wouldn't address the
root cause.

## Adding a pre-filter to spike processing
Reducing the incoming spikes by tracking all sources terminating at the
local process didn't yield an improvement, despite rejecting ~25% of all
incoming events. Instead, a significant slow-down was observed.
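
A sketch of what such a pre-filter could look like (the source set and its
population are hypothetical; a hash probe per spike is a plausible cause of
the slow-down):
```cpp
#include <algorithm>
#include <cstdint>
#include <unordered_set>
#include <vector>

struct spike { std::uint64_t source; double time; }; // simplified stand-in

// Drop every spike whose source has no connection terminating on this rank.
void prefilter(std::vector<spike>& spikes,
               const std::unordered_set<std::uint64_t>& local_sources) {
    spikes.erase(std::remove_if(spikes.begin(), spikes.end(),
                                [&](const spike& s) { return local_sources.count(s.source) == 0; }),
                 spikes.end());
}
```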

# JUWELS Booster
## Input Deck and Configuration
- JUWELS Booster, `develbooster` queue
- CMake
  ```
  cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DARB_USE_BUNDLED_LIBS=ON \
        -DARB_VECTORIZE=ON -DARB_WITH_PYTHON=OFF -DARB_WITH_MPI=ON \
        -DBUILD_SHARED_LIBS=OFF -DCMAKE_INTERPROCEDURAL_OPTIMIZATION=ON \
        -DARB_PROFILING=ON -DCMAKE_INSTALL_PREFIX=../install -G Ninja \
        -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DARB_GPU=cuda \
        -DCMAKE_BUILD_TYPE=release
  ```
- Input
  ```json
  {
    "name": "test",
    "num-cells": 8000,
    "duration": 200,
    "min-delay": 10,
    "fan-in": 5000,
    "realtime-ratio": 0.01,
    "spike-frequency": 50,
    "threads": 4,
    "ranks": 10000
  }
  ```

## Validation
```
benchmark parameters:
  name:           test
  cells per rank: 8000
  duration:       200 ms
  fan in:         5000 connections/cell
  min delay:      10 ms
  spike freq:     50 Hz
  cell overhead:  0.01 ms to advance 1 ms
expected:
  cell advance:   16 s
  spikes:         800000000
  events:         3204710400
  spikes:         2000 per interval
  events:         151969 per cell per interval
HW resources:
  threads:        4
  ranks:          10000
```
and
```
808110000 spikes generated at rate of 10 spikes per cell
```
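
The expected spike count is consistent with the input deck:

$$
N_\text{spikes} = \text{ranks} \times \text{cells per rank} \times f \times T
                = 10000 \times 8000 \times 50\,\text{Hz} \times 0.2\,\text{s}
                = 8 \times 10^8,
$$

i.e. $50\,\text{Hz} \times 0.2\,\text{s} = 10$ spikes per cell; the observed
808110000 is within about 1% of that.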
## Baseline
### Summary
```
---- meters -------------------------------------------------------------------------------
meter                         time(s)      memory(MB)
-------------------------------------------------------------------------------------------
model-init                    113.175         288.951
model-run                     107.741        1306.015
meter-total                   220.916        1594.966
```

### Profiler
```
REGION                          CALLS      WALL     THREAD        %
root                                -    78.541    314.163    100.0
  communication                     -    36.991    147.964     47.1
    enqueue                         -    21.247     84.986     27.1
      sort                     320000    19.952     79.807     25.4
      merge                    320000     1.174      4.695      1.5
      setup                    320000     0.121      0.483      0.2
    walkspikes                     40    15.735     62.940     20.0
    exchange                        -     0.010      0.038      0.0
      sort                         40     0.006      0.022      0.0
      gatherlocal                  40     0.004      0.016      0.0
      gather                       40     0.000      0.000      0.0
        remote                     40     0.000      0.000      0.0
          post_process             40     0.000      0.000      0.0
```

### Perf
```
 Performance counter stats for 'bin/drybench ../example/drybench/params.json':

         245572.30 msec task-clock:u              #    1.000 CPUs utilized
                 0      context-switches:u        #    0.000 /sec
                 0      cpu-migrations:u          #    0.000 /sec
          29077724      page-faults:u             #  118.408 K/sec
      730063012768      cycles:u                  #    2.973 GHz                      (83.33%)
       25160708320      stalled-cycles-frontend:u #    3.45% frontend cycles idle     (83.33%)
      332882140632      stalled-cycles-backend:u  #   45.60% backend cycles idle      (83.33%)
     1080323993228      instructions:u            #    1.48  insn per cycle
                                                  #    0.31  stalled cycles per insn  (83.33%)
      243594581289      branches:u                #  991.946 M/sec                    (83.33%)
        3447791949      branch-misses:u           #    1.42% of all branches          (83.33%)

     245.609154303 seconds time elapsed

     218.071863000 seconds user
      25.114123000 seconds sys
```

## Feature Branch
### Summary
```
---- meters -------------------------------------------------------------------------------
meter                         time(s)      memory(MB)
-------------------------------------------------------------------------------------------
model-init                    112.901         939.580
model-run                      84.730         871.134
meter-total                   197.631        1810.714
```

### Profiler
```
REGION                              CALLS      WALL     THREAD        %
root                                    -    71.717    286.869    100.0
  communication                         -    30.408    121.633     42.4
    enqueue                             -    20.006     80.023     27.9
      sort                         320000    18.984     75.938     26.5
      merge                        320000     0.937      3.746      1.3
      setup                        320000     0.085      0.340      0.1
    walkspikes                         40    10.401     41.605     14.5
    exchange                            -     0.001      0.005      0.0
      sort                             40     0.001      0.004      0.0
      gatherlocal                      40     0.000      0.001      0.0
      gather                           40     0.000      0.000      0.0
        remote                         40     0.000      0.000      0.0
          post_process                 40     0.000      0.000      0.0
    spikeio                            40     0.000      0.000      0.0
```

### Perf
```
 Performance counter stats for 'bin/drybench /p/project/cslns/hater1/arbor/example/drybench/params.json':

         221852.22 msec task-clock:u              #    1.000 CPUs utilized
                 0      context-switches:u        #    0.000 /sec
                 0      cpu-migrations:u          #    0.000 /sec
          28014394      page-faults:u             #  126.275 K/sec
      658832257282      cycles:u                  #    2.970 GHz                      (83.33%)
       49927962696      stalled-cycles-frontend:u #    7.58% frontend cycles idle     (83.33%)
      285682904743      stalled-cycles-backend:u  #   43.36% backend cycles idle      (83.33%)
      943000553536      instructions:u            #    1.43  insn per cycle
                                                  #    0.30  stalled cycles per insn  (83.33%)
      212435336483      branches:u                #  957.553 M/sec                    (83.33%)
        3372189010      branch-misses:u           #    1.59% of all branches          (83.33%)

     221.907422663 seconds time elapsed

     195.279197000 seconds user
      24.508514000 seconds sys
```

## Sorting events by `time` only

With `pdqsort`:
```
---- meters -------------------------------------------------------------------------------
meter                         time(s)      memory(MB)
-------------------------------------------------------------------------------------------
model-init                    111.256         939.580
model-run                      79.065         871.150
meter-total                   190.321        1810.730
```
With `util::sort_by`:
```
---- meters -------------------------------------------------------------------------------
meter                         time(s)      memory(MB)
-------------------------------------------------------------------------------------------
model-init                    111.666         939.581
model-run                      78.900         871.131
meter-total                   190.565        1810.712
```

### Profiler
```
REGION                              CALLS      WALL     THREAD        %
root                                    -    68.611    274.442    100.0
  communication                         -    27.576    110.306     40.2
    enqueue                             -    17.728     70.912     25.8
      sort                         320000    16.853     67.410     24.6
      merge                        320000     0.808      3.231      1.2
      setup                        320000     0.068      0.270      0.1
    walkspikes                         40     9.848     39.391     14.4
```

# Conclusion 
We find a 30% decrease in time spent on spike walking and a 10% decrease
end-to-end. Note that this case is extremely heavy on the communication
part and advances cells at 100x _realtime_, i.e. 1 ms wallclock per 100 ms
of biological time, which is far faster than any cable cell we encounter
in the wild.

By sorting on the `time` field only -- we don't care about the ordering
beyond that -- we find another 5%.
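
Concretely, the comparison collapses from lexicographic over
`(time, target, weight)` to the `time` field alone; with the
`util::sort_by` projection helper mentioned above, a sketch (assuming only
temporal order matters to the consumers of the queues):
```cpp
// Deliver in temporal order; ties between targets may land in any order.
util::sort_by(events, [](const spike_event& e) { return e.time; });
```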

## Side note: Memory measurement
The memory measurement isn't trustworthy: the ~290 MB reported above for
the baseline turned into ~1300 MB on repeating the benchmark.

# Related Issues

Closes #502
Authored by thorstenhater on Nov 21, 2023 · parent 65290d6 · commit 29d84bd.
Showing 12 changed files with 1,347 additions and 141 deletions.

# Changed Files
## arbor/communication/communicator.cpp (+84, -61)

```diff
@@ -62,6 +62,7 @@ void communicator::update_connections(const connectivity& rec,
     PE(init:communicator:update:clear);
     // Forget all lingering information
     connections_.clear();
+    ext_connections_.clear();
     connection_part_.clear();
     index_divisions_.clear();
     PL();
@@ -123,8 +124,8 @@ void communicator::update_connections(const connectivity& rec,
     // to do this in place.
     // NOTE: The connections are partitioned by the domain of their source gid.
     PE(init:communicator:update:connections);
-    connections_.resize(n_cons);
-    ext_connections_.resize(n_ext_cons);
+    std::vector<connection> connections(n_cons);
+    std::vector<connection> ext_connections(n_ext_cons);
     auto offsets = connection_part_; // Copy, as we use this as the list of current target indices to write into
     std::size_t ext = 0;
     auto src_domain = src_domains.begin();
@@ -140,15 +141,15 @@ void communicator::update_connections(const connectivity& rec,
                 auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
                 auto offset = offsets[*src_domain]++;
                 ++src_domain;
-                connections_[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
+                connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
             }
             for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) {
                 const auto& conn = gid_ext_connections[cidx];
                 auto src = global_cell_of(conn.source);
                 auto src_gid = conn.source.rid;
                 if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
                 auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
-                ext_connections_[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
+                ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
                 ++ext;
             }
         }
@@ -168,9 +169,14 @@ void communicator::update_connections(const connectivity& rec,
     const auto& cp = connection_part_;
     threading::parallel_for::apply(0, num_domains_, thread_pool_.get(),
                                    [&](cell_size_type i) {
-                                       util::sort(util::subrange_view(connections_, cp[i], cp[i+1]));
+                                       util::sort(util::subrange_view(connections, cp[i], cp[i+1]));
                                    });
-    std::sort(ext_connections_.begin(), ext_connections_.end());
+    std::sort(ext_connections.begin(), ext_connections.end());
     PL();
+
+    PE(init:communicator:update:destructure_connections);
+    connections_.make(connections);
+    ext_connections_.make(ext_connections);
+    PL();
 }
 
@@ -181,12 +187,12 @@ std::pair<cell_size_type, cell_size_type> communicator::group_queue_range(cell_s
 
 time_type communicator::min_delay() {
     time_type res = std::numeric_limits<time_type>::max();
-    res = std::accumulate(connections_.begin(), connections_.end(),
-                          res,
-                          [](auto&& acc, auto&& el) { return std::min(acc, time_type(el.delay)); });
-    res = std::accumulate(ext_connections_.begin(), ext_connections_.end(),
-                          res,
-                          [](auto&& acc, auto&& el) { return std::min(acc, time_type(el.delay)); });
+    res = std::accumulate(connections_.delays.begin(), connections_.delays.end(),
+                          res,
+                          [](auto&& acc, time_type del) { return std::min(acc, del); });
+    res = std::accumulate(ext_connections_.delays.begin(), ext_connections_.delays.end(),
+                          res,
+                          [](auto&& acc, time_type del) { return std::min(acc, del); });
     res = distributed_->min(res);
     return res;
 }
@@ -228,19 +234,38 @@ void communicator::set_remote_spike_filter(const spike_predicate& p) { remote_sp
 void communicator::remote_ctrl_send_continue(const epoch& e) { distributed_->remote_ctrl_send_continue(e); }
 void communicator::remote_ctrl_send_done() { distributed_->remote_ctrl_send_done(); }
 
+// Given
+// * a set of connections and an index into the set
+// * a range of spikes
+// * an output queue,
+// append events for that sub-range of spikes to the
+// queue that has the same source as the connection
+// at index.
+template<typename It>
+void enqueue_from_source(const communicator::connection_list& cons,
+                         const size_t idx,
+                         It& spk,
+                         const It end,
+                         std::vector<pse_vector>& out) {
+    // const refs to connection.
+    auto src = cons.srcs[idx];
+    auto dst = cons.dests[idx];
+    auto del = cons.delays[idx];
+    auto wgt = cons.weights[idx];
+    auto dom = cons.idx_on_domain[idx];
+    auto& que = out[dom];
+    for (; spk != end && spk->source == src; ++spk) {
+        que.emplace_back(dst, spk->time + del, wgt);
+    }
+}
+
 // Internal helper to append to the event queues
-template<typename S, typename C>
-void append_events_from_domain(C cons,
-                               S spks,
+template<typename S>
+void append_events_from_domain(const communicator::connection_list& cons, size_t cn, const size_t ce,
+                               const S& spks,
                                std::vector<pse_vector>& queues) {
-    // Predicate for partitioning
-    struct spike_pred {
-        bool operator()(const spike& spk, const cell_member_type& src) { return spk.source < src; }
-        bool operator()(const cell_member_type& src, const spike& spk) { return src < spk.source; }
-    };
-
     auto sp = spks.begin(), se = spks.end();
-    auto cn = cons.begin(), ce = cons.end();
+    if (se == sp) return;
     // We have a choice of whether to walk spikes or connections:
     // i.e., we can iterate over the spikes, and for each spike search
     // the for connections that have the same source; or alternatively
@@ -251,64 +276,62 @@ void append_events_from_domain(C cons,
     // complexity of order max(S log(C), C log(S)), where S is the
     // number of spikes, and C is the number of connections.
     if (cons.size() < spks.size()) {
-        while (cn != ce && sp != se) {
-            auto sources = std::equal_range(sp, se, cn->source, spike_pred());
-            for (auto s: util::make_range(sources)) {
-                queues[cn->index_on_domain].push_back(make_event(*cn, s));
-            }
-            sp = sources.first;
-            ++cn;
+        for (; sp != se && cn < ce; ++cn) {
+            // sp is now the beginning of a range of spikes from the same
+            // source.
+            sp = std::lower_bound(sp, se,
+                                  cons.srcs[cn],
+                                  [](const auto& spk, const auto& src) { return spk.source < src; });
+            // now, sp is at the end of the equal source range.
+            enqueue_from_source(cons, cn, sp, se, queues);
         }
     }
     else {
-        while (cn != ce && sp != se) {
-            auto targets = std::equal_range(cn, ce, sp->source);
-            for (auto c: util::make_range(targets)) {
-                queues[c.index_on_domain].push_back(make_event(c, *sp));
+        while (sp != se) {
+            auto beg = sp;
+            auto src = beg->source;
+            // Here, `cn` is the index of the first connection whose source
+            // is larger or equal to the spike's source. It may be `ce` if
+            // all elements compare < to spk.source.
+            cn = std::lower_bound(cons.srcs.begin() + cn,
+                                  cons.srcs.begin() + ce,
+                                  src)
+               - cons.srcs.begin();
+            for (; cn < ce && cons.srcs[cn] == src; ++cn) {
+                // Reset the spike iterator as we walk the same sub-range
+                // for each connection with the same source.
+                sp = beg;
+                // If we ever get multiple spikes from the same source, treat
+                // them all. This is mostly rare.
+                enqueue_from_source(cons, cn, sp, se, queues);
             }
-            cn = targets.first;
-            ++sp;
+            while (sp != se && sp->source == src) ++sp;
         }
     }
 }
 
-void communicator::make_event_queues(
-    const gathered_vector<spike>& global_spikes,
-    std::vector<pse_vector>& queues,
-    const std::vector<spike>& external_spikes) {
+void communicator::make_event_queues(communicator::spikes& spikes,
+                                     std::vector<pse_vector>& queues) {
     arb_assert(queues.size()==num_local_cells_);
-    const auto& sp = global_spikes.partition();
+    const auto& sp = spikes.from_local.partition();
     const auto& cp = connection_part_;
     for (auto dom: util::make_span(num_domains_)) {
-        append_events_from_domain(util::subrange_view(connections_, cp[dom], cp[dom+1]),
-                                  util::subrange_view(global_spikes.values(), sp[dom], sp[dom+1]),
+        append_events_from_domain(connections_, cp[dom], cp[dom+1],
+                                  util::subrange_view(spikes.from_local.values(), sp[dom], sp[dom+1]),
                                   queues);
     }
     num_local_events_ = util::sum_by(queues, [](const auto& q) {return q.size();}, num_local_events_);
     // Now that all local spikes have been processed; consume the remote events coming in.
     // - turn all gids into externals
-    auto spikes = external_spikes;
-    std::for_each(spikes.begin(),
-                  spikes.end(),
+    std::for_each(spikes.from_remote.begin(), spikes.from_remote.end(),
                   [](auto& s) { s.source = global_cell_of(s.source); });
-    append_events_from_domain(ext_connections_, spikes, queues);
+    append_events_from_domain(ext_connections_, 0, ext_connections_.size(), spikes.from_remote, queues);
 }
 
-std::uint64_t communicator::num_spikes() const {
-    return num_spikes_;
-}
-
-void communicator::set_num_spikes(std::uint64_t n) {
-    num_spikes_ = n;
-}
-
-cell_size_type communicator::num_local_cells() const {
-    return num_local_cells_;
-}
-
-const std::vector<connection>& communicator::connections() const {
-    return connections_;
-}
+std::uint64_t communicator::num_spikes() const { return num_spikes_; }
+void communicator::set_num_spikes(std::uint64_t n) { num_spikes_ = n; }
+cell_size_type communicator::num_local_cells() const { return num_local_cells_; }
+const communicator::connection_list& communicator::connections() const { return connections_; }
 
 void communicator::reset() {
     num_spikes_ = 0;
```
## arbor/communication/communicator.hpp (+37, -8)

```diff
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <vector>
+#include <unordered_set>
 
 #include <arbor/export.hpp>
 #include <arbor/common_types.hpp>
@@ -63,19 +64,14 @@ class ARB_ARBOR_API communicator {
     /// all events that must be delivered to targets in that cell group as a
     /// result of the global spike exchange, plus any events that were already
     /// in the list.
-    void make_event_queues(
-        const gathered_vector<spike>& global_spikes,
-        std::vector<pse_vector>& queues,
-        const std::vector<spike>& external_spikes={});
+    void make_event_queues(spikes& spks, std::vector<pse_vector>& queues);
 
     /// Returns the total number of global spikes over the duration of the simulation
     std::uint64_t num_spikes() const;
     void set_num_spikes(std::uint64_t n);
 
     cell_size_type num_local_cells() const;
 
-    const std::vector<connection>& connections() const;
-
     void reset();
 
     // used for commmunicate to coupled simulations
@@ -89,13 +85,46 @@ class ARB_ARBOR_API communicator {
 
     void set_remote_spike_filter(const spike_predicate&);
 
+    // TODO: This is public for now.
+    struct connection_list {
+        std::vector<cell_size_type> idx_on_domain;
+        std::vector<cell_member_type> srcs;
+        std::vector<cell_lid_type> dests;
+        std::vector<float> weights;
+        std::vector<float> delays;
+
+        void make(const std::vector<connection>& cons) {
+            clear();
+            for (const auto& con: cons) {
+                idx_on_domain.push_back(con.index_on_domain);
+                srcs.push_back(con.source);
+                dests.push_back(con.destination);
+                weights.push_back(con.weight);
+                delays.push_back(con.delay);
+            }
+        }
+
+        void clear() {
+            idx_on_domain.clear();
+            srcs.clear();
+            dests.clear();
+            weights.clear();
+            delays.clear();
+        }
+
+        size_t size() const { return srcs.size(); }
+    };
+
+    const connection_list& connections() const;
+
 private:
 
     cell_size_type num_total_cells_ = 0;
     cell_size_type num_local_cells_ = 0;
     cell_size_type num_local_groups_ = 0;
     cell_size_type num_domains_ = 0;
     // Arbor internal connections
-    std::vector<connection> connections_;
+    connection_list connections_;
     // partition of connections over the domains of the sources' ids.
     std::vector<cell_size_type> connection_part_;
     std::vector<cell_size_type> index_divisions_;
@@ -105,7 +134,7 @@ class ARB_ARBOR_API communicator {
 
     // Connections from external simulators into Arbor.
     // Currently we have no partitions/indices/acceleration structures
-    std::vector<connection> ext_connections_;
+    connection_list ext_connections_;
 
     distributed_context_handle distributed_;
     task_system_handle thread_pool_;
```
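
The cache argument behind `connection_list`, made concrete: a binary search
over the AoS `std::vector<connection>` loads a whole connection per probe,
while the SoA layout only touches the densely packed `srcs` array. A
self-contained sketch with simplified stand-in types (not Arbor's actual
definitions):
```cpp
#include <algorithm>
#include <cstdint>
#include <tuple>
#include <vector>

struct cell_member { std::uint32_t gid, lid; }; // stand-in for cell_member_type
bool operator<(const cell_member& a, const cell_member& b) {
    return std::tie(a.gid, a.lid) < std::tie(b.gid, b.lid);
}

struct connection_aos {            // ~24 bytes per element
    cell_member source;
    std::uint32_t dest, index;
    float weight, delay;
};

struct connection_soa {            // keys packed 8 bytes apart
    std::vector<cell_member> srcs;
    // dests, weights, delays, idx_on_domain elided
};

// AoS probe: every comparison drags a full 24-byte connection into cache.
std::size_t find_aos(const std::vector<connection_aos>& cons, cell_member key) {
    auto it = std::lower_bound(cons.begin(), cons.end(), key,
                               [](const connection_aos& c, const cell_member& k) { return c.source < k; });
    return it - cons.begin();
}

// SoA probe: comparisons touch only the dense key array -> 8 keys per
// 64-byte cache line instead of ~2.6 connections.
std::size_t find_soa(const connection_soa& cons, cell_member key) {
    auto it = std::lower_bound(cons.srcs.begin(), cons.srcs.end(), key);
    return it - cons.srcs.begin();
}
```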
## arbor/include/arbor/spike_event.hpp (+10, -10)

```diff
@@ -1,33 +1,33 @@
 #pragma once
 
+#include <arbor/arb_types.hpp>
+
 #include <iosfwd>
 #include <tuple>
 #include <vector>
 
 #include <arbor/export.hpp>
 #include <arbor/serdes.hpp>
 #include <arbor/common_types.hpp>
+#include <arbor/util/lexcmp_def.hpp>
 
 namespace arb {
 
 // Events delivered to targets on cells with a cell group.
 
 struct spike_event {
-    cell_lid_type target;
-    time_type time;
-    float weight;
-
-    friend bool operator==(const spike_event& l, const spike_event& r) {
-        return l.target==r.target && l.time==r.time && l.weight==r.weight;
-    }
+    cell_lid_type target = -1;
+    float weight = 0;
+    time_type time = -1;
 
-    friend bool operator<(const spike_event& l, const spike_event& r) {
-        return std::tie(l.time, l.target, l.weight) < std::tie(r.time, r.target, r.weight);
-    }
+    spike_event() = default;
+    constexpr spike_event(cell_lid_type tgt, time_type t, arb_weight_type w) noexcept: target(tgt), weight(w), time(t) {}
 
     ARB_SERDES_ENABLE(spike_event, target, time, weight);
 };
 
+ARB_DEFINE_LEXICOGRAPHIC_ORDERING(spike_event,(a.time,a.target,a.weight),(b.time,b.target,b.weight))
+
 using pse_vector = std::vector<spike_event>;
 
 struct cell_spike_events {
```