Improve partition_load_balance #2206

Merged: 15 commits, Feb 27, 2024
Changes from all commits
20 changes: 0 additions & 20 deletions arbor/communication/dry_run_context.cpp
@@ -74,26 +74,6 @@ struct dry_run_context_impl {
return gathered_vector<cell_gid_type>(std::move(gathered_gids), std::move(partition));
}

std::vector<std::vector<cell_gid_type>>
gather_gj_connections(const std::vector<std::vector<cell_gid_type>> & local_connections) const {
auto local_size = local_connections.size();
std::vector<std::vector<cell_gid_type>> global_connections;
global_connections.reserve(local_size*num_ranks_);

for (unsigned i = 0; i < num_ranks_; i++) {
util::append(global_connections, local_connections);
}

for (unsigned i = 0; i < num_ranks_; i++) {
for (unsigned j = i*local_size; j < (i+1)*local_size; j++){
for (auto& conn_gid: global_connections[j]) {
conn_gid += num_cells_per_tile_*i;
}
}
}
return global_connections;
}

cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const {
cell_label_range global_ranges;
for (unsigned i = 0; i < num_ranks_; i++) {
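For context, the deleted dry-run method built a fake global gap-junction table by replicating the local table once per simulated rank and shifting each copy's gids by one tile. A minimal standalone sketch of that logic (`replicate_with_offsets` is a hypothetical name; `R` and `T` stand in for `num_ranks_` and `num_cells_per_tile_`):

```cpp
#include <cstdint>
#include <vector>

using cell_gid_type = std::uint32_t;

// Replicate the local connection table R times; gids in copy i are shifted
// by i*T, mimicking the i-th simulated rank owning the i-th tile of cells.
std::vector<std::vector<cell_gid_type>>
replicate_with_offsets(const std::vector<std::vector<cell_gid_type>>& local,
                       unsigned R, cell_gid_type T) {
    std::vector<std::vector<cell_gid_type>> global;
    global.reserve(local.size()*R);
    for (unsigned i = 0; i < R; ++i) {
        for (const auto& row: local) {
            auto copy = row;
            for (auto& gid: copy) gid += T*i;
            global.push_back(std::move(copy));
        }
    }
    return global;
}
```

This global gather became unnecessary because, as the arbor/domain_decomposition.cpp hunk below shows, gap-junction peers are now validated directly against each group's own gid set.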
10 changes: 0 additions & 10 deletions arbor/communication/mpi_context.cpp
@@ -53,11 +53,6 @@ struct mpi_context_impl {
return mpi::gather_all_with_partition(local_gids, comm_);
}

std::vector<std::vector<cell_gid_type>>
gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const {
return mpi::gather_all(local_connections, comm_);
}

cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const {
std::vector<cell_size_type> sizes;
std::vector<cell_tag_type> labels;
@@ -141,11 +136,6 @@ struct remote_context_impl {
gathered_vector<cell_gid_type>
gather_gids(const std::vector<cell_gid_type>& local_gids) const { return mpi_.gather_gids(local_gids); }

std::vector<std::vector<cell_gid_type>>
gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const {
return mpi_.gather_gj_connections(local_connections);
}

cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const {
return mpi_.gather_cell_label_range(local_ranges);
}
14 changes: 0 additions & 14 deletions arbor/distributed_context.hpp
@@ -72,10 +72,6 @@ class distributed_context {
return impl_->gather_gids(local_gids);
}

gj_connection_vector gather_gj_connections(const gj_connection_vector& local_connections) const {
return impl_->gather_gj_connections(local_connections);
}

cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const {
return impl_->gather_cell_label_range(local_ranges);
}
@@ -117,8 +113,6 @@ class distributed_context {
remote_gather_spikes(const spike_vector& local_spikes) const = 0;
virtual gathered_vector<cell_gid_type>
gather_gids(const gid_vector& local_gids) const = 0;
virtual gj_connection_vector
gather_gj_connections(const gj_connection_vector& local_connections) const = 0;
virtual cell_label_range
gather_cell_label_range(const cell_label_range& local_ranges) const = 0;
virtual cell_labels_and_gids
@@ -154,10 +148,6 @@ class distributed_context {
gather_gids(const gid_vector& local_gids) const override {
return wrapped.gather_gids(local_gids);
}
std::vector<std::vector<cell_gid_type>>
gather_gj_connections(const gj_connection_vector& local_connections) const override {
return wrapped.gather_gj_connections(local_connections);
}
cell_label_range
gather_cell_label_range(const cell_label_range& local_ranges) const override {
return wrapped.gather_cell_label_range(local_ranges);
@@ -217,10 +207,6 @@ struct local_context {
}
void remote_ctrl_send_continue(const epoch&) const {}
void remote_ctrl_send_done() const {}
std::vector<std::vector<cell_gid_type>>
gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const {
return local_connections;
}
cell_label_range
gather_cell_label_range(const cell_label_range& local_ranges) const {
return local_ranges;
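The hunks above touch each layer of distributed_context's type erasure: the public forwarding method, the pure-virtual interface, and the templated wrapper (plus each concrete context in the earlier files). A condensed sketch of that pattern with hypothetical names (`interface`, `wrap`, `gather`), to show why removing one operation requires deletions in several places:

```cpp
#include <memory>
#include <utility>

struct interface {
    virtual int gather(int local) const = 0;  // one slot per distributed operation
    virtual ~interface() = default;
};

// Adapts any concrete context type (MPI, dry-run, local, ...) to the interface.
template <typename Impl>
struct wrap: interface {
    explicit wrap(Impl impl): wrapped(std::move(impl)) {}
    int gather(int local) const override { return wrapped.gather(local); }
    Impl wrapped;
};

class context {
public:
    template <typename Impl>
    explicit context(Impl impl): impl_(std::make_unique<wrap<Impl>>(std::move(impl))) {}
    // Public facade forwards to the erased implementation.
    int gather(int local) const { return impl_->gather(local); }
private:
    std::unique_ptr<interface> impl_;
};
```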
36 changes: 10 additions & 26 deletions arbor/domain_decomposition.cpp
@@ -13,11 +13,9 @@
#include "util/span.hpp"

namespace arb {
domain_decomposition::domain_decomposition(
const recipe& rec,
context ctx,
const std::vector<group_description>& groups)
{
domain_decomposition::domain_decomposition(const recipe& rec,
context ctx,
const std::vector<group_description>& groups) {
struct partition_gid_domain {
partition_gid_domain(const gathered_vector<cell_gid_type>& divs, unsigned domains) {
auto rank_part = util::partition_view(divs.partition());
@@ -27,9 +25,7 @@ domain_decomposition::domain_decomposition(
}
}
}
int operator()(cell_gid_type gid) const {
return gid_map.at(gid);
}
int operator()(cell_gid_type gid) const { return gid_map.at(gid); }
std::unordered_map<cell_gid_type, int> gid_map;
};

@@ -41,39 +37,27 @@

std::vector<cell_gid_type> local_gids;
for (const auto& g: groups) {
if (g.backend == backend_kind::gpu && !has_gpu) {
throw invalid_backend(domain_id);
}
if (g.backend == backend_kind::gpu && g.kind != cell_kind::cable) {
throw incompatible_backend(domain_id, g.kind);
}
if (g.backend == backend_kind::gpu && !has_gpu) throw invalid_backend(domain_id);
if (g.backend == backend_kind::gpu && g.kind != cell_kind::cable) throw incompatible_backend(domain_id, g.kind);

std::unordered_set<cell_gid_type> gid_set(g.gids.begin(), g.gids.end());
for (const auto& gid: g.gids) {
if (gid >= num_global_cells) {
throw out_of_bounds(gid, num_global_cells);
}
if (gid >= num_global_cells) throw out_of_bounds(gid, num_global_cells);
for (const auto& gj: rec.gap_junctions_on(gid)) {
if (!gid_set.count(gj.peer.gid)) {
throw invalid_gj_cell_group(gid, gj.peer.gid);
}
if (!gid_set.count(gj.peer.gid)) throw invalid_gj_cell_group(gid, gj.peer.gid);
}
}
local_gids.insert(local_gids.end(), g.gids.begin(), g.gids.end());
}
cell_size_type num_local_cells = local_gids.size();

auto global_gids = dist->gather_gids(local_gids);
if (global_gids.size() != num_global_cells) {
throw invalid_sum_local_cells(global_gids.size(), num_global_cells);
}
if (global_gids.size() != num_global_cells) throw invalid_sum_local_cells(global_gids.size(), num_global_cells);

auto global_gid_vals = global_gids.values();
util::sort(global_gid_vals);
for (unsigned i = 1; i < global_gid_vals.size(); ++i) {
if (global_gid_vals[i] == global_gid_vals[i-1]) {
throw duplicate_gid(global_gid_vals[i]);
}
if (global_gid_vals[i] == global_gid_vals[i-1]) throw duplicate_gid(global_gid_vals[i]);
}

num_domains_ = num_domains;
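The partition_gid_domain helper in the hunk above builds a hash map from every global gid to its owning rank, using the partition of the gathered gid vector. A standalone sketch of the same lookup, assuming `values` is the concatenation of all ranks' gids and `offsets` is its partition (offsets[r]..offsets[r+1] delimits rank r), mirroring what gathered_vector provides:

```cpp
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

using cell_gid_type = std::uint32_t;

std::unordered_map<cell_gid_type, int>
make_gid_map(const std::vector<cell_gid_type>& values,
             const std::vector<std::size_t>& offsets) {
    std::unordered_map<cell_gid_type, int> gid_map;
    for (std::size_t rank = 0; rank + 1 < offsets.size(); ++rank) {
        for (std::size_t i = offsets[rank]; i < offsets[rank + 1]; ++i) {
            gid_map[values[i]] = static_cast<int>(rank);  // values[i] lives on this rank
        }
    }
    return gid_map;
}
```

For example, values = {2, 0, 1, 3} with offsets = {0, 2, 4} maps gids 2 and 0 to rank 0 and gids 1 and 3 to rank 1; operator() in the real struct then reduces to a single gid_map.at(gid).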
1 change: 0 additions & 1 deletion arbor/execution_context.hpp
@@ -34,7 +34,6 @@ struct ARB_ARBOR_API execution_context {

template <typename Comm>
execution_context(const proc_allocation& resources, Comm comm, Comm remote);

};

} // namespace arb
6 changes: 4 additions & 2 deletions arbor/include/arbor/common_types.hpp
@@ -120,10 +120,12 @@ using probe_tag = int;
using sample_size_type = std::int32_t;

// Enumeration for execution back-end targets, as specified in domain decompositions.

// NOTE(important): Given in order of priority, i.e. we will attempt to schedule gpu
// groups before multicore (MC) groups, for reasons of efficiency. Ugly, but as we
// do not have more backends, this is OK for now.
enum class backend_kind {
gpu, // Use gpu back-end when supported by cell_group implementation.
multicore, // Use multicore back-end for all computation.
gpu // Use gpu back-end when supported by cell_group implementation.
};
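The reordering above makes the enum's underlying values encode scheduling priority (gpu = 0 sorts before multicore = 1). A hedged illustration of how that ordering could be exploited; `group` and `schedule_order` are hypothetical names, as the PR's actual scheduling code is not part of this diff:

```cpp
#include <algorithm>
#include <vector>

enum class backend_kind { gpu, multicore };

struct group { backend_kind backend; /* gids, kind, ... */ };

// Sort groups so that gpu-backed groups come first, relying on enum order.
void schedule_order(std::vector<group>& groups) {
    std::stable_sort(groups.begin(), groups.end(),
                     [](const group& a, const group& b) { return a.backend < b.backend; });
}
```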

// Enumeration used to identify the cell type/kind, used by the model to
Expand Down