Skip to content

Commit

Permalink
absorb mark, de-switch updates
Browse files Browse the repository at this point in the history
  • Loading branch information
thorstenhater committed Dec 4, 2024
1 parent 8a46b45 commit 6268fdb
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 61 deletions.
19 changes: 7 additions & 12 deletions arbor/backends/shared_state_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,16 @@ struct shared_state_base {
return nullptr;
}

void mark_events() {
auto d = static_cast<D*>(this);
auto& streams = d->streams;
for (auto& stream: streams) stream.second.mark();
}

void deliver_events(mechanism& m) {
auto d = static_cast<D*>(this);
auto& streams = d->streams;
if (auto it = streams.find(m.mechanism_id()); it != streams.end()) {
if (auto& deliverable_events = it->second; !deliverable_events.empty()) {
auto state = deliverable_events.marked_events();
m.deliver_events(state);
}
}
auto id = m.mechanism_id();
if (!streams.count(id)) return;
auto& stream = streams.at(id);
stream.mark();
if (stream.empty()) return;
auto marked = stream.marked_events();
m.deliver_events(marked);
}

void reset_thresholds() {
Expand Down
84 changes: 36 additions & 48 deletions arbor/fvm_lowered_cell_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ struct fvm_lowered_cell_impl: public fvm_lowered_cell {
std::vector<mechanism_ptr> revpot_mechanisms_;
std::vector<mechanism_ptr> voltage_mechanisms_;

// track synapses, prng users, post eventers
std::vector<mechanism*> has_targets_;
std::vector<mechanism*> has_prng_;
std::vector<mechanism*> has_post_event_;

// Handles for accessing event targets.
std::vector<target_handle> target_handles_;
// Lookup table for target ids -> local target handle indices.
Expand All @@ -87,9 +92,6 @@ struct fvm_lowered_cell_impl: public fvm_lowered_cell {
// random number generator seed value
arb_seed_type seed_;

// Flag indicating that at least one of the mechanisms implements the post_events procedure
bool post_events_ = false;

void update_ion_state();

// Throw if absolute value of membrane voltage exceeds bounds.
Expand Down Expand Up @@ -180,34 +182,22 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate(const timestep_
arb_assert(state_->time == ts.t_begin());

// Update integration step time information visible to mechanisms.
for (auto& m: mechanisms_) {
m->set_dt(state_->dt);
}
for (auto& m: revpot_mechanisms_) {
m->set_dt(state_->dt);
}
for (auto& m: voltage_mechanisms_) {
m->set_dt(state_->dt);
}
for (auto& m: mechanisms_) m->set_dt(state_->dt);
for (auto& m: revpot_mechanisms_) m->set_dt(state_->dt);
for (auto& m: voltage_mechanisms_) m->set_dt(state_->dt);

// Update any required reversal potentials based on ionic concentrations
for (auto& m: revpot_mechanisms_) {
m->update_current();
}
for (auto& m: revpot_mechanisms_) m->update_current();

PE(advance:integrate:current:zero);
state_->zero_currents();
PL();

// Deliver events and accumulate mechanism current contributions.
// apply relevant events and drop them afterwards
for (auto& m: has_targets_) state_->deliver_events(*m);

// Mark all events due before (but not including) the end of this time step (state_->time_to) for delivery
state_->mark_events();
for (auto& m: mechanisms_) {
// apply the events and drop them afterwards
state_->deliver_events(*m);
m->update_current();
}
// accumulate mechanism current contributions.
for (auto& m: mechanisms_) m->update_current();

// Add stimulus current contributions.
// NOTE: performed after dt, time_to calculation, in case we want to
Expand All @@ -226,11 +216,10 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate(const timestep_
state_->integrate_cable_state();
PL();

// do PRNG update, where needed
for (auto& m: has_prng_) state_->update_prng_state(*m);
// Integrate mechanism state for density
for (auto& m: mechanisms_) {
state_->update_prng_state(*m);
m->update_state();
}
for (auto& m: mechanisms_) m->update_state();

// Update ion concentrations.
PE(advance:integrate:ionupdate);
Expand All @@ -239,25 +228,16 @@ fvm_integration_result fvm_lowered_cell_impl<Backend>::integrate(const timestep_

// voltage mechs run now; after the cable_solver, but before the
// threshold test
for (auto& m: voltage_mechanisms_) {
m->update_current();
}
for (auto& m: voltage_mechanisms_) {
state_->update_prng_state(*m);
m->update_state();
}
for (auto& m: voltage_mechanisms_) m->update_current();
for (auto& m: voltage_mechanisms_) m->update_state();

// Update time and test for spike threshold crossings.
PE(advance:integrate:threshold);
state_->test_thresholds();
PL();

PE(advance:integrate:post);
if (post_events_) {
for (auto& m: mechanisms_) {
m->post_event();
}
}
for (auto& m: has_post_event_) m->post_event();
PL();

// Advance epoch
Expand Down Expand Up @@ -428,15 +408,15 @@ fvm_lowered_cell_impl<Backend>::initialize(const std::vector<cell_gid_type>& gid
fvm_mechanism_data mech_data = fvm_build_mechanism_data(global_props, cells, gids, gj_conns, D, context_);

// Fill src_to_spike and cv_to_cell vectors only if mechanisms with post_events implemented are present.
post_events_ = mech_data.post_events;
bool post_events = mech_data.post_events;
auto max_detector = 0;
if (post_events_) {
if (post_events) {
auto it = util::max_element_by(fvm_info.num_sources, [](auto elem) {return util::second(elem);});
max_detector = it->second;
}
std::vector<arb_index_type> src_to_spike, cv_to_cell;

if (post_events_) {
if (post_events) {
for (auto cell_idx: make_span(ncell)) {
for (auto lid: make_span(fvm_info.num_sources[gids[cell_idx]])) {
src_to_spike.push_back(cell_idx * max_detector + lid);
Expand Down Expand Up @@ -474,7 +454,6 @@ fvm_lowered_cell_impl<Backend>::initialize(const std::vector<cell_gid_type>& gid
// Keep track of mechanisms by name for probe lookup.
std::unordered_map<std::string, mechanism*> mechptr_by_name;
target_handles_.resize(mech_data.n_target);

unsigned mech_id = 0;
for (const auto& [name, config]: mech_data.mechanisms) {
mechanism_layout layout;
Expand All @@ -495,7 +474,6 @@ fvm_lowered_cell_impl<Backend>::initialize(const std::vector<cell_gid_type>& gid
case arb_mechanism_kind_point:
// Point mechanism contributions are in [nA]; CV area A in [µm^2].
// F = 1/A * [nA/µm²] / [A/m²] = 1000/A.

layout.gid.resize(config.cv.size());
layout.idx.resize(layout.gid.size());
for (auto i: count_along(config.cv)) {
Expand All @@ -506,7 +484,6 @@ fvm_lowered_cell_impl<Backend>::initialize(const std::vector<cell_gid_type>& gid
layout.idx[i] = i - idx_offset;

if (config.target.empty()) continue;

target_handle handle(mech_id, i);
if (config.multiplicity.empty()) {
target_handles_[config.target[i]] = handle;
Expand Down Expand Up @@ -550,11 +527,21 @@ fvm_lowered_cell_impl<Backend>::initialize(const std::vector<cell_gid_type>& gid
auto [mech, over] = mech_instance(name);
state_->instantiate(*mech, mech_id, over, layout, config.param_values);
mechptr_by_name[name] = mech.get();
if (mech->mech_.n_random_variables) has_prng_.push_back(mech.get());
if (mech->mech_.has_post_events) has_post_event_.push_back(mech.get());
if (fvm_info.num_targets_per_mech_id[mech_id]) has_targets_.push_back(mech.get());
++mech_id;


switch (config.kind) {
case arb_mechanism_kind_gap_junction:
case arb_mechanism_kind_point:
case arb_mechanism_kind_gap_junction: {
mechanisms_.emplace_back(mech.release());
break;
}
case arb_mechanism_kind_point: {
mechanisms_.emplace_back(mech.release());
break;
}
case arb_mechanism_kind_density: {
mechanisms_.emplace_back(mech.release());
break;
Expand All @@ -567,8 +554,9 @@ fvm_lowered_cell_impl<Backend>::initialize(const std::vector<cell_gid_type>& gid
voltage_mechanisms_.emplace_back(mech.release());
break;
}
default:;
default: {
throw invalid_mechanism_kind(config.kind);
}
}
}

Expand Down
1 change: 0 additions & 1 deletion test/unit/test_synapses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ TEST(synapses, syn_basic_state) {
std::vector<size_t> divs{0, handles.size()};
auto ctx = arb::make_context();
state.begin_epoch(lanes, {}, dts, handles, divs, ctx->thread_pool);
state.mark_events();

state.deliver_events(*expsyn);
state.deliver_events(*exp2syn);
Expand Down

0 comments on commit 6268fdb

Please sign in to comment.