Feature: High-level network specification #2050

Merged
merged 98 commits into from
May 22, 2024
Changes from 95 commits
cc609bf
replace domain index with gid in connection struct
AdhocMan Feb 27, 2023
dac34ad
wip
AdhocMan Mar 3, 2023
76109ba
reset label resolution
AdhocMan Mar 6, 2023
27798a1
network_impl
AdhocMan Mar 6, 2023
ba9a199
simplify
AdhocMan Mar 6, 2023
525a50d
recv send
AdhocMan Mar 8, 2023
fae2651
complete generation
AdhocMan Mar 11, 2023
f536c13
implemented construction
AdhocMan Mar 12, 2023
ab02df6
support for named selection / value
AdhocMan Mar 12, 2023
d1f1e41
add label parsing (incomplete)
AdhocMan Mar 14, 2023
5bb3244
first python interface, fix string lifetime
AdhocMan Mar 17, 2023
b9dcfa6
label parse
AdhocMan Mar 17, 2023
16dc1e9
add rotation / translation, more selections
AdhocMan Mar 25, 2023
edae2d9
use mpoint
AdhocMan Mar 25, 2023
6882c59
all labels for selection
AdhocMan Mar 25, 2023
c90dd98
value labels
AdhocMan Mar 26, 2023
1593e28
implemented other cell types
AdhocMan Mar 26, 2023
443ea8e
now using network description to generate connections
AdhocMan Mar 26, 2023
8c47e96
chain instead of ring
AdhocMan Apr 2, 2023
f1b351c
rename selection labels
AdhocMan Apr 2, 2023
b993f2f
gid_range added to common types
AdhocMan Apr 11, 2023
35dc0ed
renaming and test label parsing
AdhocMan Apr 13, 2023
7b0dbf1
add math operation to network value
AdhocMan Apr 13, 2023
343218b
add network unit tests
AdhocMan Apr 20, 2023
229da63
move label parsing to new files
AdhocMan Apr 21, 2023
ce75385
network_value operators
AdhocMan May 2, 2023
2fec3f1
examples
AdhocMan May 2, 2023
2eb1c90
doc update
AdhocMan May 2, 2023
edb90bf
fix mpi compilation
AdhocMan May 11, 2023
2ab0d89
Add if-else network value
AdhocMan May 15, 2023
75fa66a
generalized ring exchange
AdhocMan May 18, 2023
f666941
add network connection export and multi-threading
AdhocMan Jul 2, 2023
15516b0
add tree and distributed for each tests
AdhocMan Jul 24, 2023
9cbabec
add network generation test
AdhocMan Jul 26, 2023
508bcc6
more effective multi-threading
AdhocMan Jul 28, 2023
f82fb2d
merge master
AdhocMan Aug 15, 2023
fc38ec8
fix mpi compilation
AdhocMan Aug 15, 2023
a5bf351
documentation
AdhocMan Aug 22, 2023
8dae792
api documentation
AdhocMan Aug 24, 2023
5ac7c5c
Merge remote-tracking branch 'upstream/master' into network
AdhocMan Aug 25, 2023
4b33622
reformat example
AdhocMan Aug 25, 2023
6e44819
fix flake8 warning
AdhocMan Aug 25, 2023
d8ae7d2
fix unit test linking with shared library
AdhocMan Aug 25, 2023
d21b51d
add connection print out to examples
AdhocMan Aug 28, 2023
bbdc519
add new examples to scripts
AdhocMan Aug 28, 2023
25086f1
reformat example
AdhocMan Aug 28, 2023
30ccb9b
fix ci script value
AdhocMan Aug 28, 2023
c8421fd
Switch std::string -> hashes for label resolution.
thorstenhater Aug 29, 2023
d13a5b1
Invariant checks for hash collisions.
thorstenhater Aug 29, 2023
ee037fc
Account for same keys on cell label range.
thorstenhater Aug 29, 2023
1a74723
review
AdhocMan Aug 31, 2023
71d116b
added "visit_variant" alternative to std::visit for better performance
AdhocMan Sep 1, 2023
d4e334e
improved multithreading by lock removal
AdhocMan Sep 2, 2023
76d44ba
move mapping of gid to local domain index into domain decomposition
AdhocMan Sep 3, 2023
365f9e3
fix communicator test
AdhocMan Sep 3, 2023
11ae954
rename destination to target
AdhocMan Sep 3, 2023
7e1aaf1
fix test
AdhocMan Sep 4, 2023
3dd836c
revert to float type in spike event
AdhocMan Sep 4, 2023
55a376f
python reformatting
AdhocMan Sep 4, 2023
384dff7
remove reference from string_view
AdhocMan Sep 6, 2023
8339d57
Merge remote-tracking branch 'upstream/master' into network
AdhocMan Sep 6, 2023
f008f72
Push hashes up into decor.
thorstenhater Sep 12, 2023
eb2c7fb
Merge remote-tracking branch 'origin/master' into perf/hashed-labels
thorstenhater Sep 12, 2023
d8546fc
Map back to labels.
thorstenhater Sep 12, 2023
b8d0329
Merge remote-tracking branch 'thorsten/perf/hashed-labels' into netwo…
AdhocMan Sep 13, 2023
c94f1c2
use hashes
AdhocMan Sep 19, 2023
cf2a5f8
Add internal_hash to hash_def and integrate w/ hash_value.
thorstenhater Oct 17, 2023
83e6c44
Merge remote-tracking branch 'thorsten/perf/hashed-labels' into netwo…
AdhocMan Oct 20, 2023
682a6aa
Add some testing, treat pointers.
thorstenhater Nov 22, 2023
fbdf107
Merge remote-tracking branch 'origin/master' into perf/hashed-labels
thorstenhater Nov 22, 2023
ef792c9
Merge part two.
thorstenhater Nov 22, 2023
97ca7a4
The missing test.
thorstenhater Nov 22, 2023
61bd0b4
merge label hash branch
AdhocMan Nov 28, 2023
0cec44f
Merge remote-tracking branch 'origin/master' into perf/hashed-labels
thorstenhater Nov 29, 2023
c583b2f
Shuffle internal hash and combine into a detail namespace.
thorstenhater Nov 29, 2023
297dea9
merge label hash branch 2
AdhocMan Nov 29, 2023
c37f603
merge master
AdhocMan Dec 20, 2023
1b3bbb8
use cbprng random generator definition
AdhocMan Dec 20, 2023
a96cbf1
refactor
AdhocMan Dec 22, 2023
8d2983f
Merge remote-tracking branch 'upstream/master' into network_hash
AdhocMan Dec 22, 2023
dec80fa
fix example
AdhocMan Dec 22, 2023
f9db23a
merge
AdhocMan Jan 23, 2024
0e75615
fix python example
AdhocMan Jan 23, 2024
20902f1
change back pybind11 submodule
AdhocMan Apr 5, 2024
1e07cdd
add noexcept to spatial_tree
AdhocMan Apr 5, 2024
fc444f6
improved error message
AdhocMan Apr 8, 2024
654d935
fix doc for random selection
AdhocMan Apr 8, 2024
55eff36
replace pybind arg_v
AdhocMan Apr 8, 2024
43de6a1
doc
AdhocMan Apr 8, 2024
1596219
distributed for each doc
AdhocMan Apr 9, 2024
d178b07
move call out of inner loop
AdhocMan Apr 9, 2024
1bc1807
merge master
AdhocMan Apr 9, 2024
23e9659
fix warning
AdhocMan Apr 9, 2024
953b1ce
noexcept fix attempt
AdhocMan Apr 9, 2024
6ad6c30
fix dummy context test
AdhocMan Apr 9, 2024
340da63
revert std::swap usage
AdhocMan Apr 10, 2024
534c559
use lowest()
AdhocMan Apr 10, 2024
31fbbe4
add documentation for generate_network_connections function
AdhocMan May 13, 2024
2 changes: 2 additions & 0 deletions arbor/CMakeLists.txt
Expand Up @@ -44,6 +44,8 @@ set(arbor_sources
morph/segment_tree.cpp
morph/stitch.cpp
merge_events.cpp
network.cpp
network_impl.cpp
simulation.cpp
partition_load_balance.cpp
profile/clock.cpp
Expand Down
53 changes: 35 additions & 18 deletions arbor/communication/communicator.cpp
Expand Up @@ -14,6 +14,7 @@
#include "connection.hpp"
#include "distributed_context.hpp"
#include "execution_context.hpp"
#include "network_impl.hpp"
#include "profile/profiler_macro.hpp"
#include "threading/threading.hpp"
#include "util/partition.hpp"
Expand All @@ -24,14 +25,12 @@

namespace arb {

communicator::communicator(const recipe& rec,
const domain_decomposition& dom_dec,
execution_context& ctx): num_total_cells_{rec.num_cells()},
num_local_cells_{dom_dec.num_local_cells()},
num_local_groups_{dom_dec.num_groups()},
num_domains_{(cell_size_type) ctx.distributed->size()},
distributed_{ctx.distributed},
thread_pool_{ctx.thread_pool} {}
communicator::communicator(const recipe& rec, const domain_decomposition& dom_dec, context ctx):
num_total_cells_{rec.num_cells()},
num_local_cells_{dom_dec.num_local_cells()},
num_local_groups_{dom_dec.num_groups()},
num_domains_{(cell_size_type)ctx->distributed->size()},
ctx_(std::move(ctx)) {}

constexpr inline
bool is_external(cell_gid_type c) {
Expand All @@ -55,7 +54,7 @@ cell_member_type global_cell_of(const cell_member_type& c) {
return {c.gid | msb, c.index};
}

void communicator::update_connections(const connectivity& rec,
void communicator::update_connections(const recipe& rec,
const domain_decomposition& dom_dec,
const label_resolution_map& source_resolution_map,
const label_resolution_map& target_resolution_map) {
Expand All @@ -67,6 +66,9 @@ void communicator::update_connections(const connectivity& rec,
index_divisions_.clear();
PL();

// Construct connections from high-level specification
auto generated_connections = generate_connections(rec, ctx_, dom_dec);

// Make a list of local cells' connections
// -> gid_connections
// Count the number of local connections (i.e. connections terminating on this domain)
Expand Down Expand Up @@ -114,9 +116,18 @@ void communicator::update_connections(const connectivity& rec,
}
part_ext_connections.push_back(gid_ext_connections.size());
}
for (const auto& c: generated_connections) {
auto sgid = c.source.gid;
if (sgid >= num_total_cells_) {
throw arb::bad_connection_source_gid(c.source.gid, sgid, num_total_cells_);
}
const auto src = dom_dec.gid_domain(sgid);
src_domains.push_back(src);
src_counts[src]++;
}

util::make_partition(connection_part_, src_counts);
auto n_cons = gid_connections.size();
auto n_cons = gid_connections.size() + generated_connections.size();
auto n_ext_cons = gid_ext_connections.size();
PL();

Expand All @@ -132,6 +143,7 @@ void communicator::update_connections(const connectivity& rec,
auto target_resolver = resolver(&target_resolution_map);
for (const auto index: util::make_span(num_local_cells_)) {
const auto tgt_gid = gids[index];
const auto iod = dom_dec.index_on_domain(tgt_gid);
auto source_resolver = resolver(&source_resolution_map);
for (const auto cidx: util::make_span(part_connections[index], part_connections[index+1])) {
const auto& conn = gid_connections[cidx];
Expand All @@ -141,18 +153,23 @@ void communicator::update_connections(const connectivity& rec,
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
auto offset = offsets[*src_domain]++;
++src_domain;
connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, iod};
}
for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) {
const auto& conn = gid_ext_connections[cidx];
auto src = global_cell_of(conn.source);
auto src_gid = conn.source.rid;
if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, iod};
++ext;
}
}
for (const auto& c: generated_connections) {
auto offset = offsets[*src_domain]++;
++src_domain;
connections[offset] = c;
}
PL();

PE(init:communicator:update:index);
Expand All @@ -167,7 +184,7 @@ void communicator::update_connections(const connectivity& rec,
// Sort the connections for each domain.
// This is num_domains_ independent sorts, so it can be parallelized trivially.
const auto& cp = connection_part_;
threading::parallel_for::apply(0, num_domains_, thread_pool_.get(),
threading::parallel_for::apply(0, num_domains_, ctx_->thread_pool.get(),
[&](cell_size_type i) {
util::sort(util::subrange_view(connections, cp[i], cp[i+1]));
});
Expand All @@ -193,7 +210,7 @@ time_type communicator::min_delay() {
res = std::accumulate(ext_connections_.delays.begin(), ext_connections_.delays.end(),
res,
[](auto&& acc, time_type del) { return std::min(acc, del); });
res = distributed_->min(res);
res = ctx_->distributed->min(res);
return res;
}

Expand All @@ -206,7 +223,7 @@ communicator::exchange(std::vector<spike> local_spikes) {

PE(communication:exchange:gather);
// global all-to-all to gather a local copy of the global spike list on each node.
auto global_spikes = distributed_->gather_spikes(local_spikes);
auto global_spikes = ctx_->distributed->gather_spikes(local_spikes);
num_spikes_ += global_spikes.size();
PL();

Expand All @@ -217,7 +234,7 @@ communicator::exchange(std::vector<spike> local_spikes) {
local_spikes.end(),
[this] (const auto& s) { return !remote_spike_filter_(s); }));
}
auto remote_spikes = distributed_->remote_gather_spikes(local_spikes);
auto remote_spikes = ctx_->distributed->remote_gather_spikes(local_spikes);
PL();

PE(communication:exchange:gather:remote:post_process);
Expand All @@ -231,8 +248,8 @@ communicator::exchange(std::vector<spike> local_spikes) {
}

void communicator::set_remote_spike_filter(const spike_predicate& p) { remote_spike_filter_ = p; }
void communicator::remote_ctrl_send_continue(const epoch& e) { distributed_->remote_ctrl_send_continue(e); }
void communicator::remote_ctrl_send_done() { distributed_->remote_ctrl_send_done(); }
void communicator::remote_ctrl_send_continue(const epoch& e) { ctx_->distributed->remote_ctrl_send_continue(e); }
void communicator::remote_ctrl_send_done() { ctx_->distributed->remote_ctrl_send_done(); }

// Given
// * a set of connections and an index into the set
Expand Down
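The placement step in update_connections above can be read as a counting pass followed by a bucketed write. The source domain of every connection (local first, then generated) is recorded in visiting order, the per-domain counts become a partition, and each connection is then written at the next free slot of its domain. The following is a minimal sketch of that idea only; `toy_connection` and `group_by_domain` are hypothetical names for illustration and are not part of Arbor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a connection; not Arbor's connection type.
struct toy_connection {
    unsigned source_gid;
    float weight;
};

// Groups connections by source domain.
// src_domains[i] is the domain owning the source of cons[i].
// On return, part has num_domains + 1 entries and the connections of
// domain d occupy out[part[d]] .. out[part[d + 1] - 1].
std::vector<toy_connection> group_by_domain(const std::vector<toy_connection>& cons,
                                            const std::vector<int>& src_domains,
                                            int num_domains,
                                            std::vector<std::size_t>& part) {
    assert(cons.size() == src_domains.size());

    // Pass 1: count connections per source domain.
    std::vector<std::size_t> counts(num_domains, 0);
    for (int d: src_domains) ++counts[d];

    // Prefix sum of the counts gives the partition boundaries.
    part.assign(num_domains + 1, 0);
    for (int d = 0; d < num_domains; ++d) part[d + 1] = part[d] + counts[d];

    // Pass 2: write each connection at the next free slot of its domain.
    std::vector<std::size_t> offsets(part.begin(), part.end() - 1);
    std::vector<toy_connection> out(cons.size());
    for (std::size_t i = 0; i < cons.size(); ++i) {
        out[offsets[src_domains[i]]++] = cons[i];
    }
    return out;
}
```

Because both passes visit the connections in the same order, appending the generated connections after the local ones in each pass is sufficient to merge them into the same partition; the per-domain sort that follows in the diff then orders each bucket.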
13 changes: 6 additions & 7 deletions arbor/communication/communicator.hpp
@@ -1,11 +1,11 @@
#pragma once

#include <vector>
#include <unordered_set>

#include <arbor/export.hpp>
#include <arbor/common_types.hpp>
#include <arbor/context.hpp>
#include <arbor/domain_decomposition.hpp>
#include <arbor/export.hpp>
#include <arbor/recipe.hpp>
#include <arbor/spike.hpp>

Expand Down Expand Up @@ -40,7 +40,7 @@ class ARB_ARBOR_API communicator {

explicit communicator(const recipe& rec,
const domain_decomposition& dom_dec,
execution_context& ctx);
context ctx);

/// The range of event queues that belong to cells in group i.
std::pair<cell_size_type, cell_size_type> group_queue_range(cell_size_type i);
Expand Down Expand Up @@ -78,7 +78,7 @@ class ARB_ARBOR_API communicator {
void remote_ctrl_send_continue(const epoch&);
void remote_ctrl_send_done();

void update_connections(const connectivity& rec,
void update_connections(const recipe& rec,
const domain_decomposition& dom_dec,
const label_resolution_map& source_resolution_map,
const label_resolution_map& target_resolution_map);
Expand All @@ -98,7 +98,7 @@ class ARB_ARBOR_API communicator {
for (const auto& con: cons) {
idx_on_domain.push_back(con.index_on_domain);
srcs.push_back(con.source);
dests.push_back(con.destination);
dests.push_back(con.target);
weights.push_back(con.weight);
delays.push_back(con.delay);
}
Expand Down Expand Up @@ -136,10 +136,9 @@ class ARB_ARBOR_API communicator {
// Currently we have no partitions/indices/acceleration structures
connection_list ext_connections_;

distributed_context_handle distributed_;
task_system_handle thread_pool_;
std::uint64_t num_spikes_ = 0u;
std::uint64_t num_local_events_ = 0u;
context ctx_;
};

} // namespace arb
185 changes: 185 additions & 0 deletions arbor/communication/distributed_for_each.hpp
@@ -0,0 +1,185 @@
#pragma once

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <type_traits>
#include <utility>

#include "distributed_context.hpp"
#include "util/range.hpp"

namespace arb {

namespace impl {
template <class FUNC, typename... T, std::size_t... Is>
void for_each_in_tuple(FUNC&& func, std::tuple<T...>& t, std::index_sequence<Is...>) {
(func(Is, std::get<Is>(t)), ...);
}

template <class FUNC, typename... T>
void for_each_in_tuple(FUNC&& func, std::tuple<T...>& t) {
for_each_in_tuple(func, t, std::index_sequence_for<T...>());
}

template <class FUNC, typename... T1, typename... T2, std::size_t... Is>
void for_each_in_tuple_pair(FUNC&& func,
std::tuple<T1...>& t1,
std::tuple<T2...>& t2,
std::index_sequence<Is...>) {
(func(Is, std::get<Is>(t1), std::get<Is>(t2)), ...);
}

template <class FUNC, typename... T1, typename... T2>
void for_each_in_tuple_pair(FUNC&& func, std::tuple<T1...>& t1, std::tuple<T2...>& t2) {
for_each_in_tuple_pair(func, t1, t2, std::index_sequence_for<T1...>());
}

} // namespace impl


/*
* Collective operation, calling func on args supplied by each rank exactly once. The order of calls
* is unspecified. Requires
*
* - Item = util::range<ARGS>::value_type to be identical across all ranks
* - Item is trivially_copyable
* - Alignment of Item must not exceed std::max_align_t
* - func to be a callable type with signature
* void func(util::range<Item*>...)
* - func must not modify contents of range
* - All ranks in distributed must call this collectively.
*/
template <typename FUNC, typename... ARGS>
void distributed_for_each(FUNC&& func,
const distributed_context& distributed,
const util::range<ARGS>&... args) {

static_assert(sizeof...(args) > 0);
auto arg_tuple = std::forward_as_tuple(args...);

struct vec_info {
std::size_t offset; // offset in bytes
std::size_t size; // size in bytes
};

std::array<vec_info, sizeof...(args)> info;
std::size_t buffer_size = 0;

// Compute offsets in bytes for each vector when placed in common buffer
{
std::size_t offset = info.size() * sizeof(vec_info);
impl::for_each_in_tuple(
[&](std::size_t i, auto&& vec) {
using T = typename std::remove_reference_t<decltype(vec)>::value_type;
static_assert(std::is_trivially_copyable_v<T>);
static_assert(alignof(std::max_align_t) >= alignof(T));
static_assert(alignof(std::max_align_t) % alignof(T) == 0);

// make sure alignment of offset fulfills requirement
const auto alignment_excess = offset % alignof(T);
offset += alignment_excess > 0 ? alignof(T) - (alignment_excess) : 0;

const auto size_in_bytes = vec.size() * sizeof(T);

info[i].size = size_in_bytes;
info[i].offset = offset;

buffer_size = offset + size_in_bytes;
offset += size_in_bytes;
},
arg_tuple);
}

// compute maximum buffer size between ranks, such that we only allocate once
const std::size_t max_buffer_size = distributed.max(buffer_size);

std::tuple<util::range<typename std::remove_reference_t<decltype(args)>::value_type*>...>
ranges;

if (max_buffer_size == info.size() * sizeof(vec_info)) {
// if all empty, call function with empty ranges for each step and exit
impl::for_each_in_tuple_pair(
[&](std::size_t i, auto&& vec, auto&& r) {
using T = typename std::remove_reference_t<decltype(vec)>::value_type;
r = util::range<T*>(nullptr, nullptr);
},
arg_tuple,
ranges);

for (int step = 0; step < distributed.size(); ++step) { std::apply(func, ranges); }
return;
}

// use malloc for std::max_align_t alignment
auto deleter = [](char* ptr) { std::free(ptr); };
std::unique_ptr<char[], void (*)(char*)> buffer((char*)std::malloc(max_buffer_size), deleter);
std::unique_ptr<char[], void (*)(char*)> recv_buffer(
(char*)std::malloc(max_buffer_size), deleter);

// copy offset and size info to front of buffer
std::memcpy(buffer.get(), info.data(), info.size() * sizeof(vec_info));

// copy each vector to each location in buffer
impl::for_each_in_tuple(
[&](std::size_t i, auto&& vec) {
using T = typename std::remove_reference_t<decltype(vec)>::value_type;
std::copy(vec.begin(), vec.end(), (T*)(buffer.get() + info[i].offset));
},
arg_tuple);


const auto my_rank = distributed.id();
const auto left_rank = my_rank == 0 ? distributed.size() - 1 : my_rank - 1;
const auto right_rank = my_rank == distributed.size() - 1 ? 0 : my_rank + 1;

// exchange buffer in ring pattern and apply function at each step
for (int step = 0; step < distributed.size() - 1; ++step) {
// always expect to receive the max size but send the actual size. MPI_Recv only expects a
// max size, not the actual size.
const auto current_info = (const vec_info*)buffer.get();

auto request = distributed.send_recv_nonblocking(max_buffer_size,
recv_buffer.get(),
right_rank,
current_info[info.size() - 1].offset + current_info[info.size() - 1].size,
buffer.get(),
left_rank,
0);

// update ranges
impl::for_each_in_tuple_pair(
[&](std::size_t i, auto&& vec, auto&& r) {
using T = typename std::remove_reference_t<decltype(vec)>::value_type;
r = util::range<T*>((T*)(buffer.get() + current_info[i].offset),
(T*)(buffer.get() + current_info[i].offset + current_info[i].size));
},
arg_tuple,
ranges);

// call provided function with ranges pointing to current buffer
std::apply(func, ranges);

request.finalize();
buffer.swap(recv_buffer);
}

// final step does not require any exchange
const auto current_info = (const vec_info*)buffer.get();
impl::for_each_in_tuple_pair(
[&](std::size_t i, auto&& vec, auto&& r) {
using T = typename std::remove_reference_t<decltype(vec)>::value_type;
r = util::range<T*>((T*)(buffer.get() + current_info[i].offset),
(T*)(buffer.get() + current_info[i].offset + current_info[i].size));
},
arg_tuple,
ranges);

// call provided function with ranges pointing to current buffer
std::apply(func, ranges);
}

} // namespace arb
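The ring pattern used by distributed_for_each can be checked without MPI by modelling each rank's buffer as a single integer naming its original owner. The sketch below is a hypothetical in-process model (`simulate_ring` is not an Arbor function). It assumes, as the argument order of send_recv_nonblocking above suggests, that each rank receives from its right neighbour and sends to its left one; the direction does not affect the property being illustrated. After `num_ranks` steps, with `num_ranks - 1` exchanges, every rank has processed the data of every rank exactly once.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Returns, for each rank, the sequence of original owners whose data that
// rank processed. held[r] models the buffer currently held by rank r.
std::vector<std::vector<int>> simulate_ring(int num_ranks) {
    assert(num_ranks > 0);

    std::vector<int> held(num_ranks);
    for (int r = 0; r < num_ranks; ++r) held[r] = r; // each rank starts with its own data

    std::vector<std::vector<int>> seen(num_ranks);
    for (int step = 0; step < num_ranks; ++step) {
        // Stand-in for "call func on the current buffer".
        for (int r = 0; r < num_ranks; ++r) seen[r].push_back(held[r]);

        // The final step does not require any exchange.
        if (step == num_ranks - 1) break;

        // Each rank receives the buffer of its right neighbour
        // (equivalently, sends its own buffer to the left).
        std::vector<int> next(num_ranks);
        for (int r = 0; r < num_ranks; ++r) {
            const int right = (r == num_ranks - 1) ? 0 : r + 1;
            next[r] = held[right];
        }
        held = std::move(next); // plays the role of buffer.swap(recv_buffer)
    }
    return seen;
}
```

In the real implementation the exchange is posted before func is applied and finalized afterwards, so communication overlaps with computation; the model above only captures the order in which data arrives.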