Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into doc/fix-math-and-wa…
Browse files Browse the repository at this point in the history
…rnings
  • Loading branch information
thorstenhater committed Dec 8, 2023
2 parents 1bd8362 + 5a965a3 commit 327c56d
Show file tree
Hide file tree
Showing 40 changed files with 1,747 additions and 418 deletions.
5 changes: 2 additions & 3 deletions arbor/benchmark_cell_group.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include <chrono>
#include <exception>

#include <arbor/arbexcept.hpp>
#include <arbor/benchmark_cell.hpp>
Expand Down Expand Up @@ -40,8 +39,8 @@ benchmark_cell_group::benchmark_cell_group(const std::vector<cell_gid_type>& gid
for (const auto& c: cells_) {
cg_sources.add_cell();
cg_targets.add_cell();
cg_sources.add_label(c.source, {0, 1});
cg_targets.add_label(c.target, {0, 1});
cg_sources.add_label(hash_value(c.source), {0, 1});
cg_targets.add_label(hash_value(c.target), {0, 1});
}

benchmark_cell_group::reset();
Expand Down
18 changes: 12 additions & 6 deletions arbor/cable_cell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ struct cable_cell_impl {
decor decorations;

// The placeable label to lid_range map
dynamic_typed_map<constant_type<std::unordered_multimap<cell_tag_type, lid_range>>::type> labeled_lid_ranges;
dynamic_typed_map<constant_type<std::unordered_multimap<hash_type, lid_range>>::type> labeled_lid_ranges;

cable_cell_impl(const arb::morphology& m, const label_dict& labels, const decor& decorations):
provider(m, labels),
Expand Down Expand Up @@ -120,7 +120,7 @@ struct cable_cell_impl {
}

template <typename Item>
void place(const locset& ls, const Item& item, const cell_tag_type& label) {
void place(const locset& ls, const Item& item, const hash_type& label) {
auto& mm = get_location_map(item);
cell_lid_type& lid = placed_count.get<Item>();
cell_lid_type first = lid;
Expand Down Expand Up @@ -226,7 +226,8 @@ void cable_cell_impl::init(const decor& d) {
for (const auto& p: d.placements()) {
auto& where = std::get<0>(p);
auto& label = std::get<2>(p);
std::visit([this, &where, &label] (auto&& what) {return this->place(where, what, label);}, std::get<1>(p));
std::visit([this, &where, &label] (auto&& what) {return this->place(where, what, label); },
std::get<1>(p));
}
}

Expand Down Expand Up @@ -280,16 +281,21 @@ const cable_cell_parameter_set& cable_cell::default_parameters() const {
return impl_->decorations.defaults();
}

const std::unordered_multimap<cell_tag_type, lid_range>& cable_cell::detector_ranges() const {
const cable_cell::lid_range_map& cable_cell::detector_ranges() const {
return impl_->labeled_lid_ranges.get<threshold_detector>();
}

const std::unordered_multimap<cell_tag_type, lid_range>& cable_cell::synapse_ranges() const {
const cable_cell::lid_range_map& cable_cell::synapse_ranges() const {
return impl_->labeled_lid_ranges.get<synapse>();
}

const std::unordered_multimap<cell_tag_type, lid_range>& cable_cell::junction_ranges() const {
const cable_cell::lid_range_map& cable_cell::junction_ranges() const {
return impl_->labeled_lid_ranges.get<junction>();
}

cell_tag_type decor::tag_of(hash_type hash) const {
if (!hashes_.count(hash)) throw arbor_internal_error{util::pprintf("Unknown hash for {}.", std::to_string(hash))};
return hashes_.at(hash);
}

} // namespace arb
12 changes: 8 additions & 4 deletions arbor/cable_cell_param.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
#include <cfloat>
#include <cmath>
#include <memory>
#include <numeric>
#include <vector>
#include <variant>
#include <tuple>
Expand All @@ -10,7 +7,9 @@
#include <arbor/cable_cell_param.hpp>
#include <arbor/s_expr.hpp>

#include <arbor/util/hash_def.hpp>
#include "util/maputil.hpp"
#include "util/strprintf.hpp"

namespace arb {

Expand Down Expand Up @@ -120,7 +119,12 @@ decor& decor::paint(region where, paintable what) {
}

decor& decor::place(locset where, placeable what, cell_tag_type label) {
placements_.emplace_back(std::move(where), std::move(what), std::move(label));
auto hash = hash_value(label);
if (hashes_.count(hash) && hashes_.at(hash) != label) {
throw arbor_internal_error{util::strprintf("Hash collision {} ./. {}", label, hashes_.at(hash))};
}
placements_.emplace_back(std::move(where), std::move(what), hash);
hashes_.emplace(hash, label);
return *this;
}

Expand Down
145 changes: 84 additions & 61 deletions arbor/communication/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ void communicator::update_connections(const connectivity& rec,
PE(init:communicator:update:clear);
// Forget all lingering information
connections_.clear();
ext_connections_.clear();
connection_part_.clear();
index_divisions_.clear();
PL();
Expand Down Expand Up @@ -123,8 +124,8 @@ void communicator::update_connections(const connectivity& rec,
// to do this in place.
// NOTE: The connections are partitioned by the domain of their source gid.
PE(init:communicator:update:connections);
connections_.resize(n_cons);
ext_connections_.resize(n_ext_cons);
std::vector<connection> connections(n_cons);
std::vector<connection> ext_connections(n_ext_cons);
auto offsets = connection_part_; // Copy, as we use this as the list of current target indices to write into
std::size_t ext = 0;
auto src_domain = src_domains.begin();
Expand All @@ -140,15 +141,15 @@ void communicator::update_connections(const connectivity& rec,
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
auto offset = offsets[*src_domain]++;
++src_domain;
connections_[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
}
for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) {
const auto& conn = gid_ext_connections[cidx];
auto src = global_cell_of(conn.source);
auto src_gid = conn.source.rid;
if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
ext_connections_[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
++ext;
}
}
Expand All @@ -168,9 +169,14 @@ void communicator::update_connections(const connectivity& rec,
const auto& cp = connection_part_;
threading::parallel_for::apply(0, num_domains_, thread_pool_.get(),
[&](cell_size_type i) {
util::sort(util::subrange_view(connections_, cp[i], cp[i+1]));
util::sort(util::subrange_view(connections, cp[i], cp[i+1]));
});
std::sort(ext_connections_.begin(), ext_connections_.end());
std::sort(ext_connections.begin(), ext_connections.end());
PL();

PE(init:communicator:update:destructure_connections);
connections_.make(connections);
ext_connections_.make(ext_connections);
PL();
}

Expand All @@ -181,12 +187,12 @@ std::pair<cell_size_type, cell_size_type> communicator::group_queue_range(cell_s

time_type communicator::min_delay() {
time_type res = std::numeric_limits<time_type>::max();
res = std::accumulate(connections_.begin(), connections_.end(),
res,
[](auto&& acc, auto&& el) { return std::min(acc, time_type(el.delay)); });
res = std::accumulate(ext_connections_.begin(), ext_connections_.end(),
res,
[](auto&& acc, auto&& el) { return std::min(acc, time_type(el.delay)); });
res = std::accumulate(connections_.delays.begin(), connections_.delays.end(),
res,
[](auto&& acc, time_type del) { return std::min(acc, del); });
res = std::accumulate(ext_connections_.delays.begin(), ext_connections_.delays.end(),
res,
[](auto&& acc, time_type del) { return std::min(acc, del); });
res = distributed_->min(res);
return res;
}
Expand Down Expand Up @@ -228,19 +234,38 @@ void communicator::set_remote_spike_filter(const spike_predicate& p) { remote_sp
void communicator::remote_ctrl_send_continue(const epoch& e) { distributed_->remote_ctrl_send_continue(e); }
void communicator::remote_ctrl_send_done() { distributed_->remote_ctrl_send_done(); }

// Given
// * a set of connections and an index into the set
// * a range of spikes
// * an output queue,
// append events for that sub-range of spikes to the
// queue that has the same source as the connection
// at index.
template<typename It>
void enqueue_from_source(const communicator::connection_list& cons,
const size_t idx,
It& spk,
const It end,
std::vector<pse_vector>& out) {
// const refs to connection.
auto src = cons.srcs[idx];
auto dst = cons.dests[idx];
auto del = cons.delays[idx];
auto wgt = cons.weights[idx];
auto dom = cons.idx_on_domain[idx];
auto& que = out[dom];
for (; spk != end && spk->source == src; ++spk) {
que.emplace_back(dst, spk->time + del, wgt);
}
}

// Internal helper to append to the event queues
template<typename S, typename C>
void append_events_from_domain(C cons,
S spks,
template<typename S>
void append_events_from_domain(const communicator::connection_list& cons, size_t cn, const size_t ce,
const S& spks,
std::vector<pse_vector>& queues) {
// Predicate for partitioning
struct spike_pred {
bool operator()(const spike& spk, const cell_member_type& src) { return spk.source < src; }
bool operator()(const cell_member_type& src, const spike& spk) { return src < spk.source; }
};

auto sp = spks.begin(), se = spks.end();
auto cn = cons.begin(), ce = cons.end();
if (se == sp) return;
// We have a choice of whether to walk spikes or connections:
// i.e., we can iterate over the spikes, and for each spike search
// the for connections that have the same source; or alternatively
Expand All @@ -251,64 +276,62 @@ void append_events_from_domain(C cons,
// complexity of order max(S log(C), C log(S)), where S is the
// number of spikes, and C is the number of connections.
if (cons.size() < spks.size()) {
while (cn != ce && sp != se) {
auto sources = std::equal_range(sp, se, cn->source, spike_pred());
for (auto s: util::make_range(sources)) {
queues[cn->index_on_domain].push_back(make_event(*cn, s));
}
sp = sources.first;
++cn;
for (; sp != se && cn < ce; ++cn) {
// sp is now the beginning of a range of spikes from the same
// source.
sp = std::lower_bound(sp, se,
cons.srcs[cn],
[](const auto& spk, const auto& src) { return spk.source < src; });
// now, sp is at the end of the equal source range.
enqueue_from_source(cons, cn, sp, se, queues);
}
}
else {
while (cn != ce && sp != se) {
auto targets = std::equal_range(cn, ce, sp->source);
for (auto c: util::make_range(targets)) {
queues[c.index_on_domain].push_back(make_event(c, *sp));
while (sp != se) {
auto beg = sp;
auto src = beg->source;
// Here, `cn` is the index of the first connection whose source
// is larger or equal to the spike's source. It may be `ce` if
// all elements compare < to spk.source.
cn = std::lower_bound(cons.srcs.begin() + cn,
cons.srcs.begin() + ce,
src)
- cons.srcs.begin();
for (; cn < ce && cons.srcs[cn] == src; ++cn) {
// Reset the spike iterator as we walk the same sub-range
// for each connection with the same source.
sp = beg;
// If we ever get multiple spikes from the same source, treat
// them all. This is mostly rare.
enqueue_from_source(cons, cn, sp, se, queues);
}
cn = targets.first;
++sp;
while (sp != se && sp->source == src) ++sp;
}
}
}

void communicator::make_event_queues(
const gathered_vector<spike>& global_spikes,
std::vector<pse_vector>& queues,
const std::vector<spike>& external_spikes) {
void communicator::make_event_queues(communicator::spikes& spikes,
std::vector<pse_vector>& queues) {
arb_assert(queues.size()==num_local_cells_);
const auto& sp = global_spikes.partition();
const auto& sp = spikes.from_local.partition();
const auto& cp = connection_part_;
for (auto dom: util::make_span(num_domains_)) {
append_events_from_domain(util::subrange_view(connections_, cp[dom], cp[dom+1]),
util::subrange_view(global_spikes.values(), sp[dom], sp[dom+1]),
append_events_from_domain(connections_, cp[dom], cp[dom+1],
util::subrange_view(spikes.from_local.values(), sp[dom], sp[dom+1]),
queues);
}
num_local_events_ = util::sum_by(queues, [](const auto& q) {return q.size();}, num_local_events_);
// Now that all local spikes have been processed; consume the remote events coming in.
// - turn all gids into externals
auto spikes = external_spikes;
std::for_each(spikes.begin(),
spikes.end(),
std::for_each(spikes.from_remote.begin(), spikes.from_remote.end(),
[](auto& s) { s.source = global_cell_of(s.source); });
append_events_from_domain(ext_connections_, spikes, queues);
append_events_from_domain(ext_connections_, 0, ext_connections_.size(), spikes.from_remote, queues);
}

std::uint64_t communicator::num_spikes() const {
return num_spikes_;
}

void communicator::set_num_spikes(std::uint64_t n) {
num_spikes_ = n;
}

cell_size_type communicator::num_local_cells() const {
return num_local_cells_;
}

const std::vector<connection>& communicator::connections() const {
return connections_;
}
std::uint64_t communicator::num_spikes() const { return num_spikes_; }
void communicator::set_num_spikes(std::uint64_t n) { num_spikes_ = n; }
cell_size_type communicator::num_local_cells() const { return num_local_cells_; }
const communicator::connection_list& communicator::connections() const { return connections_; }

void communicator::reset() {
num_spikes_ = 0;
Expand Down
Loading

0 comments on commit 327c56d

Please sign in to comment.