Skip to content

Commit

Permalink
Merge pull request #824 from PowerGridModel/feature/refactor-optional-id
Browse files Browse the repository at this point in the history
Feature / use span instead of reference wrapper for cache sequence
  • Loading branch information
mgovers authored Nov 13, 2024
2 parents 004421a + 8b39eda commit 1624ba3
Showing 1 changed file with 43 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,9 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis

// update all components
template <cache_type_c CacheType>
void
update_component(ConstDataset const& update_data, Idx pos,
std::array<std::reference_wrapper<std::vector<Idx2D> const>, n_types> const& sequence_idx_map) {
void update_component(ConstDataset const& update_data, Idx pos, SequenceIdxView const& sequence_idx_map) {
run_functor_with_all_types_return_void([this, pos, &update_data, &sequence_idx_map]<typename CT>() {
this->update_component<CT, CacheType>(update_data, pos,
std::get<index_of_component<CT>>(sequence_idx_map).get());
this->update_component<CT, CacheType>(update_data, pos, std::get<index_of_component<CT>>(sequence_idx_map));
});
}
template <cache_type_c CacheType>
Expand Down Expand Up @@ -432,36 +429,28 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
return get_sequence(buffer_span);
}
SequenceIdx get_sequence_idx_map(ConstDataset const& update_data, Idx scenario_idx,
ComponentFlags const& to_store) const {
// TODO: (jguo) this function could be encapsulated in UpdateCompIndependence in update.hpp
return run_functor_with_all_types_return_array([this, scenario_idx, &update_data, &to_store]<typename CT>() {
if (!to_store[index_of_component<CT>]) {
return std::vector<Idx2D>{};
}
auto const independence = check_components_independence<CT>(update_data);
validate_update_data_independence(independence);
return get_component_sequence<CT>(update_data, scenario_idx, independence);
});
}
// get sequence idx map of an entire batch for fast caching of component sequences
SequenceIdx get_sequence_idx_map(ConstDataset const& update_data, ComponentFlags const& to_store) const {
ComponentFlags const& components_to_store) const {
// TODO: (jguo) this function could be encapsulated in UpdateCompIndependence in update.hpp
return run_functor_with_all_types_return_array([this, &update_data, &to_store]<typename CT>() {
if (!to_store[index_of_component<CT>]) {
return std::vector<Idx2D>{};
}
auto const independence = check_components_independence<CT>(update_data);
validate_update_data_independence(independence);
return get_component_sequence<CT>(update_data, 0, independence);
});
return run_functor_with_all_types_return_array(
[this, scenario_idx, &update_data, &components_to_store]<typename CT>() {
if (!std::get<index_of_component<CT>>(components_to_store)) {
return std::vector<Idx2D>{};
}
auto const independence = check_components_independence<CT>(update_data);
validate_update_data_independence(independence);
return get_component_sequence<CT>(update_data, scenario_idx, independence);
});
}
// Get sequence idx map of an entire batch for fast caching of component sequences.
// The sequence idx map of the batch is the same as that of the first scenario in the batch (assuming homogeneity)
// This is the entry point for permanent updates.
SequenceIdx get_sequence_idx_map(ConstDataset const& update_data) const {
constexpr ComponentFlags all_true = [] {
ComponentFlags result{};
std::ranges::fill(result, true);
return result;
}();
return get_sequence_idx_map(update_data, all_true);
return get_sequence_idx_map(update_data, 0, all_true);
}

private:
Expand Down Expand Up @@ -621,9 +610,10 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
// const ref of current instance
MainModelImpl const& base_model = *this;

// cache component update order if possible
// cache component update order where possible.
// the order for a cacheable (independent) component by definition is the same across all scenarios
auto const is_independent = is_update_independent(update_data);
all_scenarios_sequence = get_sequence_idx_map(update_data, is_independent);
all_scenarios_sequence = get_sequence_idx_map(update_data, 0, is_independent);

return [&base_model, &exceptions, &infos, &calculation_fn, &result_data, &update_data,
&all_scenarios_sequence = std::as_const(all_scenarios_sequence),
Expand All @@ -640,15 +630,8 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
auto model = copy_model_functor(start);

SequenceIdx current_scenario_sequence_cache = SequenceIdx{};
std::array<std::reference_wrapper<std::vector<Idx2D> const>, n_types> const current_scenario_sequence =
run_functor_with_all_types_return_array(
[&is_independent, &all_scenarios_sequence, &current_scenario_sequence_cache]<typename CT>() {
constexpr auto comp_idx = index_of_component<CT>;
return is_independent[comp_idx] ? std::cref(all_scenarios_sequence[comp_idx])
: std::cref(current_scenario_sequence_cache[comp_idx]);
});
auto [setup, winddown] = scenario_update_restore(model, update_data, current_scenario_sequence,
current_scenario_sequence_cache, is_independent, infos);
auto [setup, winddown] = scenario_update_restore(model, update_data, is_independent, all_scenarios_sequence,
current_scenario_sequence_cache, infos);

auto calculate_scenario = MainModelImpl::call_with<Idx>(
[&model, &calculation_fn, &result_data, &infos](Idx scenario_idx) {
Expand Down Expand Up @@ -720,28 +703,41 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
};
}

static auto scenario_update_restore(
MainModelImpl& model, ConstDataset const& update_data,
std::array<std::reference_wrapper<std::vector<Idx2D> const>, n_types> const& scenario_sequence,
SequenceIdx& current_scenario_sequence_cache, ComponentFlags const& is_independent,
std::vector<CalculationInfo>& infos) noexcept {
static auto scenario_update_restore(MainModelImpl& model, ConstDataset const& update_data,
ComponentFlags const& is_independent, SequenceIdx const& all_scenario_sequence,
SequenceIdx& current_scenario_sequence_cache,
std::vector<CalculationInfo>& infos) noexcept {
auto do_update_cache = [&is_independent] {
ComponentFlags result;
std::ranges::transform(is_independent, result.begin(), std::logical_not<>{});
return result;
}();

auto const scenario_sequence = [&all_scenario_sequence, &current_scenario_sequence_cache,
&is_independent]() -> SequenceIdxView {
return run_functor_with_all_types_return_array(
[&all_scenario_sequence, &current_scenario_sequence_cache, &is_independent]<typename CT>() {
constexpr auto comp_idx = index_of_component<CT>;
if (std::get<comp_idx>(is_independent)) {
return std::span<Idx2D const>{std::get<comp_idx>(all_scenario_sequence)};
}
return std::span<Idx2D const>{std::get<comp_idx>(current_scenario_sequence_cache)};
});
};

return std::make_pair(
[&model, &update_data, &scenario_sequence, &current_scenario_sequence_cache,
[&model, &update_data, scenario_sequence, &current_scenario_sequence_cache,
do_update_cache_ = std::move(do_update_cache), &infos](Idx scenario_idx) {
Timer const t_update_model(infos[scenario_idx], 1200, "Update model");
current_scenario_sequence_cache =
model.get_sequence_idx_map(update_data, scenario_idx, do_update_cache_);
model.template update_component<cached_update_t>(update_data, scenario_idx, scenario_sequence);

model.template update_component<cached_update_t>(update_data, scenario_idx, scenario_sequence());
},
[&model, &scenario_sequence, &current_scenario_sequence_cache, &infos](Idx scenario_idx) {
[&model, scenario_sequence, &current_scenario_sequence_cache, &infos](Idx scenario_idx) {
Timer const t_update_model(infos[scenario_idx], 1201, "Restore model");
model.restore_components(scenario_sequence);

model.restore_components(scenario_sequence());
std::ranges::for_each(current_scenario_sequence_cache,
[](auto& comp_seq_idx) { comp_seq_idx.clear(); });
});
Expand Down

0 comments on commit 1624ba3

Please sign in to comment.